"""
Compute the coherence of two signals
"""
import numpy as np
import matplotlib.pyplot as plt

# Make a little extra space between the subplots.  The three panels below are
# stacked in a single column, so the gap that matters is the vertical one
# (hspace); wspace only separates columns and would have no visible effect.
plt.subplots_adjust(hspace=0.5)

# --- Sampling grid ----------------------------------------------------------
dt = 0.01                                      # sampling interval
Nb_t = 3000                                    # number of samples
tMax = Nb_t*dt                                 # total duration
# Build the time axis from an integer index scaled by dt.  A float-step
# np.arange(0, tMax, dt) can round to one sample more or fewer than Nb_t;
# this form always yields exactly Nb_t samples with the same values.
t = np.arange(Nb_t)*dt

# --- Independent white-noise sources -----------------------------------------
nse0 = np.random.randn(len(t))                 # white noise 0 (shared by both signals)
nse1 = np.random.randn(len(t))                 # white noise 1 (only in signal 1)
nse2 = np.random.randn(len(t))                 # white noise 2 (only in signal 2)

# --- Damped-oscillator impulse response used to color the noise -------------
tau = 0.09*tMax                                # decay time of the envelope
omega = 2*np.pi*15                             # angular frequency of the oscillation
print('damped osc filter with omega tau / 2pi = %3.4g' % (omega*tau/(2*np.pi)))
r = np.exp(-t/tau)*np.cos(omega*t)

# Mixing weights: A0 scales the shared noise, A1 the private noise.
# A0**2 + A1**2 == 1, so each mixture of unit-variance noises keeps unit variance.
A0 = 0.2
A1 = np.sqrt(1.-A0**2)
# Both mixtures go through the same filter r; mode='same' keeps the output
# the same length as t.
cnse1 = np.convolve(A0*nse0 + A1*nse1, r, mode='same')*dt   # colored noise 1
cnse2 = np.convolve(A0*nse0 + A1*nse2, r, mode='same')*dt   # colored noise 2

# --- Two signals with a coherent part and a random part ----------------------
A_coh = 0.01                                   # amplitude of the common sine
f_coh = 10                                     # frequency of the common sine
coherent = A_coh*np.sin(2*np.pi*f_coh*t)       # identical component in s1 and s2
s1 = coherent + cnse1
s2 = coherent + cnse2

# Top panel: both signals overlaid, zoomed on the start of the record.
plt.subplot(3, 1, 1)
plt.plot(t, s1)
plt.plot(t, s2)
plt.xlim(0, 5)
plt.xlabel('time')
plt.ylabel('s1 and s2')

# Middle panel: power spectral density of each signal, estimated with
# 256-point segments at sampling frequency 1/dt.  psd returns the spectrum
# and its frequency axis; f is overwritten by each call.
plt.subplot(3, 1, 2)
cxx, f = plt.psd(s1, NFFT=256, Fs=1./dt, label=r'$S_{1}$')
cyy, f = plt.psd(s2, NFFT=256, Fs=1./dt, label=r'$S_{2}$')
plt.ylabel('PS')
plt.legend()

# Bottom panel: coherence between s1 and s2 with the same segment length
# and sampling frequency as the spectra above.
plt.subplot(3, 1, 3)
cxy, f = plt.cohere(s1, s2, NFFT=256, Fs=1./dt, label=r'$S_{12}$')
plt.ylabel('CCS')

# Non-blocking show: control returns to the caller right after drawing.
plt.show(block=False)