Introduction to EOFs

#!wget https://downloads.psl.noaa.gov/Datasets/COBE/sst.mon.mean.nc
import numpy as np
import xarray as xr
from scipy import signal
import numpy.polynomial.polynomial as poly

import matplotlib.pyplot as plt
from eofs.xarray import Eof
import cartopy.crs as ccrs
import cartopy.feature as cfeature
n_samples = 1000

x1 = np.random.normal(0, 2, n_samples)
x2 = 0.8 * x1 + np.random.normal(0, 1, n_samples)
data = np.vstack((x1, x2)).T  # Combine into a 2D array

# Give the time dimension an explicit coordinate so that the eofs xarray solver
# can identify it as the sampling dimension.
ds = xr.DataArray(
    data,
    dims=("time", "space"),
    coords={"time": np.arange(n_samples), "space": ["X1", "X2"]},
)

fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(x1, x2, alpha=0.6, label="Data")
plt.xlabel("X1")
plt.ylabel("X2")
plt.grid(True)
plt.legend()
plt.show()
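Because x2 is largely a rescaled copy of x1 plus noise, the two variables are strongly correlated. A quick sanity check of the sample correlation (given the parameters above it should come out near 0.85):

# Sample correlation between the two synthetic variables.
print(np.corrcoef(x1, x2)[0, 1])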
[Figure: scatter plot of X1 against X2]
solver = Eof(ds)
eofs = solver.eofs(neofs=2)  # Get EOF spatial patterns
pcs = solver.pcs(npcs=2, pcscaling=1)  # Get principal components
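How much of the total variance does each mode capture? The solver exposes this through its varianceFraction method; for this toy example the first mode should dominate:

# Fraction of variance captured by each mode (mode 1 should account for most of it,
# since X2 is essentially a noisy, rescaled copy of X1).
print(solver.varianceFraction().values)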
# Draw the two EOF directions as arrows from the origin.
plt.quiver(0, 0, eofs.isel(mode=0)[0], eofs.isel(mode=0)[1])
plt.quiver(0, 0, eofs.isel(mode=1)[0], eofs.isel(mode=1)[1])
<matplotlib.quiver.Quiver at 0x7faf068e0150>
[Figure: EOF 1 and EOF 2 directions as arrows from the origin]
mean_x1, mean_x2 = np.mean(x1), np.mean(x2)
scaling = 5  # scale factor so the EOF arrows are visible against the spread of the data

fig, ax = plt.subplots(figsize=(8, 6))
plt.scatter(x1, x2, alpha=0.6, label="Data")
plt.grid(True)

# EOF1: dominant axis of variability
plt.quiver(mean_x1, mean_x2, eofs.sel(space="X1")[0] * scaling, eofs.sel(space="X2")[0] * scaling,
           color="red", scale_units="xy", scale=1, label="EOF 1")

# EOF2: Second dominant axis
plt.quiver(mean_x1, mean_x2, eofs.sel(space="X1")[1] * scaling, eofs.sel(space="X2")[1] * scaling,
           color="blue", scale_units="xy", scale=1, label="EOF 2")

# Plot styling
plt.axhline(0, color="grey", lw=0.5)
plt.axvline(0, color="grey", lw=0.5)
plt.xlabel("X1")
plt.ylabel("X2")
plt.title("Scatter Plot with EOF Axes (using eofs)")
plt.legend()
<matplotlib.legend.Legend at 0x7faf1019dc50>
[Figure: Scatter Plot with EOF Axes (using eofs)]
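The PCs are the coordinates of each sample along the EOF axes, so plotting them against one another shows the same cloud of points in the rotated, decorrelated coordinate system. A short sketch using the pcs computed above:

# Scatter of PC 1 vs PC 2: the original data expressed in EOF coordinates;
# by construction the two PCs are uncorrelated.
plt.figure(figsize=(6, 6))
plt.scatter(pcs.isel(mode=0), pcs.isel(mode=1), alpha=0.6)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.grid(True)
plt.show()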

How did we do that?
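In matrix form, with the anomaly matrix $\mathbf{X}'$ of shape (time, space), the EOFs are the eigenvectors of the spatial covariance matrix and the PCs are the projections of the anomalies onto them:

$$
\mathbf{C} = \frac{1}{n-1}\,\mathbf{X}'^{\top}\mathbf{X}', \qquad
\mathbf{C}\,\mathbf{e}_k = \lambda_k\,\mathbf{e}_k, \qquad
\mathrm{PC}_k = \mathbf{X}'\,\mathbf{e}_k .
$$

The cell below does exactly this with np.cov, np.linalg.eigh and np.dot.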

data.shape # is already (time, space)
(1000, 2)
data_mean = np.mean(data, axis=0)                        # time mean of each variable
data_anomalies = data - data_mean                        # remove the mean to form anomalies
cov_matrix = np.cov(data_anomalies, rowvar=False)        # covariance across space
eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)   # solve for eigenvalues/vectors
sorted_indices = np.argsort(eigenvalues)[::-1]           # order modes by explained variance
eigenvalues = eigenvalues[sorted_indices]
eigenvectors = -1 * eigenvectors[:, sorted_indices]      # sign is arbitrary; flipped here to match the eofs output
pcs = np.dot(data_anomalies, eigenvectors)               # project anomalies onto the eigenvectors
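As a consistency check (a small sketch, not part of the original analysis), the hand-rolled eigenvectors should agree with the solver's EOFs up to an arbitrary sign:

# Columns of eigenvectors are modes, while the solver returns one mode per row,
# so compare against the transpose; EOF signs are arbitrary, hence the abs().
print(np.allclose(np.abs(eigenvectors.T), np.abs(eofs.values)))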
# Each column of eigenvectors is one mode: plot both eigenvector directions.
plt.quiver(0, 0, eigenvectors[0, 0], eigenvectors[1, 0], color="red", alpha=0.5)
plt.quiver(0, 0, eigenvectors[0, 1], eigenvectors[1, 1], color="red", alpha=0.5)
<matplotlib.quiver.Quiver at 0x7faf0694c1d0>
[Figure: eigenvector directions from the manual calculation]
# Overlay the hand-computed eigenvectors (red) on the solver's EOFs;
# apart from a possible sign flip they should coincide.
plt.quiver(0, 0, eigenvectors[0, 0], eigenvectors[1, 0], color="red", alpha=0.5)
plt.quiver(0, 0, eigenvectors[0, 1], eigenvectors[1, 1], color="red", alpha=0.5)
plt.quiver(0, 0, eofs.isel(mode=0)[0], eofs.isel(mode=0)[1], alpha=0.5)
plt.quiver(0, 0, eofs.isel(mode=1)[0], eofs.isel(mode=1)[1], alpha=0.5)
<matplotlib.quiver.Quiver at 0x7faf0680d990>
[Figure: manually computed eigenvectors overlaid on the solver's EOFs]
# COBE monthly mean SST (see the wget command at the top of the notebook)
infile = "sst.mon.mean.nc"
dataset = xr.open_dataset(infile)
dataset.isel(time=0).sst.plot()
<matplotlib.collections.QuadMesh at 0x7faf0684f050>
[Figure: SST field at the first time step]
def detrend_dim(da, dim, deg=1):
    """
    Detrend data along a specified dimension.

    This function removes a polynomial trend along a given dimension by fitting 
    and subtracting a polynomial of specified degree.

    Parameters
    ----------
    da : xr.DataArray
        The data array to detrend.
    dim : str
        The dimension along which to detrend the data, typically representing time.
    deg : int, optional
        The degree of the polynomial for detrending. Default is 1 (linear detrend).

    Returns
    -------
    xr.DataArray
        The detrended data array with the fitted trend removed along the specified dimension.
    """

    p = da.polyfit(dim=dim, deg=deg)
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    return da - fit
detrendedSST = detrend_dim(dataset['sst'], dim="time")
# Remove the seasonal cycle: subtract the monthly climatology from each month.
monthly_mean = detrendedSST.groupby('time.month').mean('time')
sst_deseasonalized = detrendedSST.groupby('time.month') - monthly_mean
sst_deseasonalized = sst_deseasonalized.drop_vars("month")
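A quick optional sanity check: after subtracting the monthly climatology, the mean anomaly for each calendar month should be essentially zero everywhere.

# Residual monthly climatology of the deseasonalized anomalies; the maximum
# absolute value should be at the level of floating-point noise.
residual_clim = sst_deseasonalized.groupby('time.month').mean('time')
print(float(np.abs(residual_clim).max()))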
# Weight each grid point by the square root of the cosine of latitude so that the
# covariance is area-weighted (grid cells shrink towards the poles).
coslat = np.cos(np.deg2rad(dataset.lat.data))
wgts = np.sqrt(coslat)[..., np.newaxis]

solver = Eof(sst_deseasonalized, weights=wgts)
eof1 = solver.eofs(neofs=10)             # leading 10 EOF spatial patterns
pc1  = solver.pcs(npcs=10, pcscaling=0)  # corresponding (unscaled) PC time series
varfrac = solver.varianceFraction()      # fraction of total variance explained by each mode
lambdas = solver.eigenvalues()           # eigenvalues of the covariance matrix
# Define parallels and meridians
parallels = np.arange(-90, 90, 30.)
meridians = np.arange(-180, 180, 30.)

# Define lon and lat
lon = dataset.lon
lat = dataset.lat

for i in range(0, 1):  # plot only the leading mode; widen the range to see more modes
    fig = plt.figure(figsize=(12, 12))  # Larger figure for better clarity

    # --- Plot EOF Field ---
    ax1 = plt.subplot(2, 1, 1, projection=ccrs.Robinson(central_longitude=180))
    
    # Plot EOF with proper transformation
    cs = eof1.isel(mode=i).plot(
        ax=ax1, 
        cmap=plt.cm.RdBu, 
        transform=ccrs.PlateCarree(),
        add_colorbar=False  # Avoid duplicate colorbars
    )
    
    # Add coastlines and gridlines
    ax1.coastlines()
    gl = ax1.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False
    gl.right_labels = False
    gl.left_labels = True
    gl.bottom_labels = True

    # Add colorbar manually
    cb = plt.colorbar(cs, ax=ax1, orientation='vertical', fraction=0.05, pad=0.05)
    cb.set_label('EOF', fontsize=12)

    # Title
    ax1.set_title(f'EOF {i + 1}', fontsize=16)

    # --- Plot PC Time Series ---
    ax2 = plt.subplot(2, 1, 2)
    pc1.isel(mode=i).plot(ax=ax2, linewidth=2, color='blue')
    ax2.axhline(0, color='k', linewidth=0.8)
    ax2.set_xlabel('Year', fontsize=12)
    ax2.set_ylabel('PC Amplitude', fontsize=12)
    ax2.set_ylim(np.min(pc1), np.max(pc1))
    ax2.set_title(f'PC Time Series - Mode {i + 1}', fontsize=14)
    ax1.set_aspect("auto")
[Figure: EOF 1 spatial pattern (Robinson projection) and PC 1 time series]
plt.plot(varfrac.isel(mode=slice(0,10)), ls="", marker=".")
[<matplotlib.lines.Line2D at 0x7faf06aa00d0>]
[Figure: variance fraction of the leading 10 modes]
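How many of these modes are meaningfully separated from their neighbours? The eofs solver implements the North et al. (1982) rule of thumb through its northTest method; a minimal sketch that adds the typical errors to the scree plot:

# Typical errors of the leading eigenvalues (North et al., 1982); with
# vfscaled=True they are expressed in the same units as the variance fractions.
errors = solver.northTest(neigs=10, vfscaled=True)
plt.errorbar(np.arange(10), varfrac.isel(mode=slice(0, 10)),
             yerr=errors, fmt=".", capsize=3)
plt.xlabel("Mode")
plt.ylabel("Variance fraction")
plt.show()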
slp = xr.open_dataset("/students_files/data/SLP.nc").msl
slp
<xarray.DataArray 'msl' (time: 1320, latitude: 181, longitude: 360)>
[86011200 values with dtype=float32]
Coordinates:
  * longitude  (longitude) float32 0.0 1.0 2.0 3.0 ... 356.0 357.0 358.0 359.0
  * latitude   (latitude) float32 90.0 89.0 88.0 87.0 ... -88.0 -89.0 -90.0
  * time       (time) datetime64[ns] 1900-01-01 1900-02-01 ... 2009-12-01
Attributes:
    units:          Pa
    long_name:      Mean sea level pressure
    standard_name:  air_pressure_at_mean_sea_level
#| export
def adjust_lon_lat(ds, lon_name, lat_name, reverse=False):
    """Convert longitudes from the 0..360 convention to -180..180 and,
    optionally, reverse the latitude axis so that it is ascending."""

    if reverse:
        ds = ds.reindex({lat_name: ds[lat_name][::-1]})

    # Map longitudes greater than 180 onto the -180..180 range.
    ds['_longitude_adjusted'] = xr.where(
        ds[lon_name] > 180,
        ds[lon_name] - 360,
        ds[lon_name])

    # Re-index along the adjusted longitudes and tidy up the coordinate names.
    ds = (ds
          .swap_dims({lon_name: '_longitude_adjusted'})
          .sel(**{'_longitude_adjusted': sorted(ds._longitude_adjusted)})
          .drop_vars(lon_name))

    ds = ds.rename({'_longitude_adjusted': lon_name})

    return ds
slp = adjust_lon_lat(slp, lon_name="longitude", lat_name="latitude", reverse=True)
slp.isel(time=0).plot()
<matplotlib.collections.QuadMesh at 0x7faeffafc790>
[Figure: SLP field at the first time step after coordinate adjustment]
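After the adjustment the longitudes should lie in the -180..180 range and the latitudes should increase from south to north; a quick check:

# Confirm the coordinate conventions produced by adjust_lon_lat.
print(float(slp.longitude.min()), float(slp.longitude.max()))
print(bool((slp.latitude.diff("latitude") > 0).all()))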
# Keep only the winter months (DJF) over the North Atlantic/European sector.
slp_djf = slp.sel(time=slp['time'].dt.month.isin([12, 1, 2]))
slp_djf = slp_djf.sel(longitude=slice(-80, 40), latitude=slice(20, 90))
# Compute anomalies by removing the time-mean.
slp_djf = slp_djf - slp_djf.mean(dim='time')

# Create an EOF solver to do the EOF analysis. Square-root of cosine of
# latitude weights are applied before the computation of EOFs.
coslat = np.cos(np.deg2rad(slp_djf.coords['latitude'].values)).clip(0., 1.)
wgts = np.sqrt(coslat)[..., np.newaxis]
solver = Eof(slp_djf, weights=wgts)

# Retrieve the leading EOF, expressed as the covariance between the leading PC
# time series and the input SLP anomalies at each grid point.
eof1 = solver.eofsAsCovariance(neofs=1)

# Plot the leading EOF expressed as covariance in the European/Atlantic domain.
proj = ccrs.Orthographic(central_longitude=-20, central_latitude=60)
ax = plt.axes(projection=proj)
ax.coastlines()
ax.set_global()
eof1.isel(mode=0).plot.pcolormesh(ax=ax, cmap=plt.cm.RdBu_r,
                         transform=ccrs.PlateCarree(), add_colorbar=False)
ax.set_title('EOF1', fontsize=16)
plt.show()
[Figure: EOF1 of DJF SLP expressed as covariance, orthographic view of the North Atlantic]
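The spatial pattern above is only half of the story; the associated PC time series (often interpreted as an NAO index for this domain) can be pulled from the same solver. A minimal sketch using unit-variance scaling (pcscaling=1):

# Leading PC of the DJF SLP anomalies, scaled to unit variance.
nao_index = solver.pcs(npcs=1, pcscaling=1)
plt.figure(figsize=(10, 3))
nao_index.isel(mode=0).plot(linewidth=1)
plt.axhline(0, color="k", linewidth=0.8)
plt.title("PC 1 of DJF SLP anomalies")
plt.show()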