# EGGPreprocessor
:::success
**Usage** Multiple
```preprocessing``` ```visualization```
**Code type** Scripts
**Laboratory instrumentation(s)**
```EGG``` ```Acoustics```
**Programming language(s)**
```Python```
**Contributor(s)** Benson
:::
## README
EGGPreprocessor allows for:
- smoothing
- high-pass filtering
- Calculation of DEGG
- Raw data visualization
See comments in the script for more information.
## Installation
The ```py``` file is on the server.
Type the following on your terminal to download the file, or just use FileZilla/VS Code.
```
scp {yourusername}@140.112.147.111:../../mnt/hdd/codes/egg_preprocessor.py {your local dir}
```
## [Demo](https://drive.google.com/file/d/1CGPReR4iG0i1iV41gstj40EfnwMN6AhC/view?usp=share_link)
<iframe
src="https://peh-suan.github.io/others/egg_preprocessor_example.html"
style="width:100%; height:300px;"
></iframe>
## Code
```python=
import tgt # for importing TextGrid
from scipy.io import wavfile # for reading wav files
import numpy as np # for mathematical processing
from tqdm import tqdm # shows progress bar
import seaborn as sns # for plotting
import matplotlib.pyplot as plt # for plotting
plt.rcParams.update({"lines.linewidth": .3})
sns.set_palette("pastel")
class EGGPreprocessor():
def __init__(self, data_files, textgrid_file=""):
self.data = {}
self.tgrd = None
if textgrid_file:
self.add_textgrid(textgrid_file)
for data_name in data_files:
self.sampling_frequency, self.data[data_name] = wavfile.read(data_files[data_name])
self.data_len = len(self.data[list(self.data.keys())[0]])
self.time = np.arange(0, self.data_len)/self.sampling_frequency
self.interval_data = {}
def change_data_name(self, original_name, target_name):
assert original_name in self.data, f"`{original_name}` not found in the data."
self.data[target_name] = self.data.pop(original_name)
def add_textgrid(self, textgrid_file=""):
self.tgrd = tgt.read_textgrid(textgrid_file)
def get_intervals(self, interval_tier_name=""):
assert self.tgrd is not None, "Please add the TextGrid first."
if interval_tier_name:
return self.tgrd.get_tier_by_name(interval_tier_name)
else:
return self.tgrd.tiers[0]
def set_interval(self, interval):
assert type(interval) is dict or type(interval) is tgt.core.Interval, "The interval should be a tgt.core.Interval or a dict containg \"start_time\", \"end_time\", and \"text\"."
if type(interval) is dict:
assert "start_time" in interval and "end_time" in interval and "text" in interval, "The interval dict should contain \"start_time\", \"end_time\", and \"text\"."
if type(interval) is tgt.core.Interval:
start_time, end_time = interval.start_time, interval.end_time
if self.interval_data!={} and self.interval_data["start_time"]==start_time and self.interval_data["end_time"]==end_time: return
self.interval_data = {"start_time": start_time, "end_time": end_time}
def get_data_within_time(self, start_time, end_time, data_name, return_time=False):
assert data_name in self.data, f"`{data_name}` not found."
data = self.data[data_name]
start_id = np.argmin(abs(self.time-start_time))
end_id = np.argmin(abs(self.time-end_time))
if return_time:
return data[start_id:end_id], self.time[start_id:end_id]
else:
return data[start_id:end_id]
def get_interval_data(self, data_name, return_time=False, save=True):
assert self.interval_data!={}, "Please set the interval first."
if data_name not in self.interval_data:
data, time = self.get_data_within_time(self.interval_data["start_time"], self.interval_data["end_time"], data_name, True)
if save:
self.interval_data[data_name] = data
self.interval_data["time"] = time
else:
data, time = self.interval_data[data_name], self.interval_data["time"]
if return_time:
return data, time
return data
def differentiate(self, data_name):
assert data_name in self.data, f"`{data_name}` not found."
return np.gradient(self.data[data_name])
def highpass_filter(self, data_name, cutoff_freq=100, order=2):
assert data_name in self.data, f"`{data_name}` not found."
from scipy import signal
b, a = signal.butter(order, cutoff_freq, btype="high", analog=False, fs=40000)
return signal.filtfilt(b, a, self.data[data_name])
def smooth(self, data_name, box_pts=30):
assert data_name in self.data, f"`{data_name}` not found."
data = self.data[data_name]
box = np.ones(box_pts)/box_pts
return np.convolve(data, box, mode="same")
def add_data(self, datas):
assert type(datas) is dict, "The data to be added should be in a dict. The key(s) are the name(s) of the data. The values are the data or the wave files."
datas_to_add = {}
for data_name in datas:
if type(datas[data_name]) is str:
sampling_frequency, data = wavfile.read(datas[data_name])
assert sampling_frequency==self.sampling_frequency, f"The data to be added `{data_name}` and other data do not have the same smapling frequencies: {sampling_frequency} and {self.sampling_frequency}."
else:
data = datas[data_name]
assert len(data)==self.data_len, f"The data to be added `{data_name}` and other data do not have the same lengths: {len(data)} and {self.data_len}."
datas_to_add[data_name] = np.array(data)
for data_name in datas_to_add:
self.data[data_name] = np.array(datas_to_add[data_name])
def delete_data(self, data_names):
if type(data_names) is str: data_names = [data_names]
to_delete = []
for data_name in data_names:
assert data_name in self.data, f"`{data_name}` not found."
to_delete+=[data_name]
assert len(set(to_delete)-set(self.data.keys()))==0, "You cannot delete every data. At leaset one data should be left."
for data_name in to_delete:
del self.data[data_name]
if "cycles" in self.interval_data and data_name in self.interval_data["cycles"]:
del self.interval_data["cycles"][data_name]
if "cycle_data" in self.interval_data:
if data_name in self.interval_data["cycle_data"]:
del self.interval_data["cycle_data"][data_name]
for this_data_name in self.interval_data["cycle_data"]:
if data_name in self.interval_data["cycle_data"][this_data_name]:
del self.interval_data["cycle_data"][this_data_name][data_name]
def _plot_data(self, data, time, return_plot, w, h):
fig, axs = plt.subplots(len(data), 1, figsize=(w, h*len(data)), sharex=True)
if len(data)==1:
axs = [axs]
for data_name, ax in zip(data, axs):
sns.lineplot(
ax=ax,
x=time,
y=data[data_name],
)
ax.set_ylabel(data_name)
axs[-1].set_xlabel("time")
fig.tight_layout()
if return_plot: return fig
def plot_data_within_time(self, start_time, end_time, data_names=None, return_plot=False, w=10, h=2):
if type(data_names) is str:
data_names = [data_names]
if not data_names:
data_names = list(self.data.keys())
data = {}
for data_name in data_names:
assert data_name in self.data, f"`{data_name}` not found."
data[data_name], time = self.get_data_within_time(start_time, end_time, data_name, return_time=True)
fig = self._plot_data(data, time, return_plot=True, w=w, h=h)
if return_plot:
return fig
def plot_interval_data(self, data_names=None, return_plot=False, w=10, h=2):
if type(data_names) is str:
data_names = [data_names]
if not data_names:
data_names = list(self.data.keys())
data = {}
for data_name in data_names:
assert data_name in self.data, f"`{data_name}` not found."
data[data_name], time = self.get_interval_data(data_name, return_time=True)
fig = self._plot_data(data, time, return_plot=True, w=w, h=h)
if return_plot:
return fig
def get_interval_egg_cycles(self, data_name, min_amp=600, save=True, return_cycles=False):
assert self.interval_data!=None, "Please set the interval first."
assert data_name in self.data, f"`{data_name}` not found."
data, time = self.get_data_within_time(self.interval_data["start_time"], self.interval_data["end_time"], data_name, return_time=True)
egg_cycles = []
start, end = None, None
max_value = -np.inf
min_value = np.inf
for idx in range(len(time)-1):
if data[idx]<=0 and data[idx+1]>=0:
if not start:
start = (time[idx]+time[idx+1])/2
else:
end = (time[idx]+time[idx+1])/2
if max_value-min_value>=min_amp:
egg_cycles+=[(start, end)]
start, end = end, None
max_value = 0
min_value = 0
max_value = max([max_value, data[idx]])
min_value = min([min_value, data[idx]])
assert len(egg_cycles)>0, "No cycle is found in this interval. Consider decreasing the minimum cycle amplitude, or select a wider time range."
print(f"{len(egg_cycles)} cycles found.")
if save:
if "cycles" not in self.interval_data:
self.interval_data["cycles"] = {}
self.interval_data["cycles"][data_name] = egg_cycles
if return_cycles: return egg_cycles
def get_interval_all_cycle_data(self, cycle_name, data_names=None, save=True, return_data=False):
assert self.interval_data!={}, "Please set the interval first."
assert "cycles" in self.interval_data and cycle_name in self.interval_data["cycles"], "Please get the EGG cycles first."
if type(data_names) is str: data_names = [data_names]
if data_names is None:
data_names = list(self.data.keys())
data_names+=["cycle", "time_point"]
data = {}
appended_data_names = []
for data_name in data_names.copy():
assert data_name in ["cycle", "time_point"] or data_name in self.data.keys(), f"`{data_name}` not found."
if "cycle_data" not in self.interval_data or cycle_name not in self.interval_data["cycle_data"] or data_name not in self.interval_data["cycle_data"][cycle_name]:
data[data_name] = np.array([])
else:
data[data_name] = self.interval_data["cycle_data"][cycle_name][data_name]
appended_data_names+=[data_name]
if len(set(data_names)-set(appended_data_names))==0:
if return_data: return data
else: return
for cycle_idx, cycle in enumerate(tqdm(self.interval_data["cycles"][cycle_name])):
start_time, end_time = cycle
start_id = np.argmin(abs(self.time-start_time))
end_id = np.argmin(abs(self.time-end_time))
for data_name in set(data_names)-set(appended_data_names)-{"cycle", "time_point"}:
cycle_data = self.data[data_name][start_id:end_id]
data[data_name] = np.append(data[data_name], cycle_data)
if len(appended_data_names)==0:
data["cycle"] = np.append(data["cycle"], np.array([cycle_idx]*len(cycle_data)))
data["time_point"] = np.append(data["time_point"], np.array(range(len(cycle_data))))
if save:
if "cycle_data" not in self.interval_data:
self.interval_data["cycle_data"] = {}
if cycle_name not in self.interval_data["cycle_data"]:
self.interval_data["cycle_data"][cycle_name] = {}
for data_name in data_names:
if data_name not in self.interval_data["cycle_data"][cycle_name]:
self.interval_data["cycle_data"][cycle_name][data_name] = data[data_name]
if return_data: return data
def plot_interval_all_cycle_data(self, cycle_name, data_names=None, w=10, h=2, save=False):
assert self.interval_data!={}, "Please set the interval first."
if not data_names:
data_names = list(self.data.keys())
n_subplot = len(data_names)
data = self.get_interval_all_cycle_data(cycle_name, data_names, return_data=True, save=save)
fig, axs = plt.subplots(n_subplot, 1, figsize=(w, h*n_subplot), sharex=True)
if n_subplot==1:
axs = [axs]
cmap = sns.color_palette("flare", as_cmap=True)
norm = plt.Normalize(data["cycle"].min(), data["cycle"].max())
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
for data_name, ax in zip(data, axs):
if data_name in ["time_point", "cycle"]: continue
if len(set(data["cycle"]))==1:
sns.lineplot(
ax=ax,
x=data["time_point"],
y=data[data_name]
)
ax.set_ylabel(data_name)
else:
sns.lineplot(
ax=ax,
x=data["time_point"],
y=data[data_name],
hue=data["cycle"],
palette=cmap
)
ax.set_ylabel(data_name)
ax.get_legend().remove()
cb = ax.figure.colorbar(sm, ax=ax)
cb.outline.set_visible(False)
fig.tight_layout()
axs[-1].set_xlabel("time_point")
```