Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes phase follow #8

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
126 changes: 112 additions & 14 deletions examples/3-event-analysis.ipynb

Large diffs are not rendered by default.

173 changes: 152 additions & 21 deletions lightguide/blast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import deque

import logging
from copy import deepcopy
Expand All @@ -8,6 +9,7 @@
TYPE_CHECKING,
Any,
Callable,
Deque,
Iterable,
Iterator,
Literal,
Expand All @@ -19,11 +21,14 @@
import numpy as np
from matplotlib import colors, dates
from matplotlib.colors import Colormap
from pyrocko import io
from pyrocko import io, pile, obspy_compat, trace
from pyrocko.trace import Trace
from scipy import signal
import re
import math

from lightguide.utils import PathStr
from lightguide.models.picks import *

from .filters import afk_filter
from .signal import decimation_coefficients
Expand Down Expand Up @@ -57,6 +62,7 @@ class Blast:

start_channel: int
channel_spacing: float
channel_list: np.ndarray

def __init__(
self,
Expand All @@ -65,6 +71,7 @@ def __init__(
sampling_rate: float,
start_channel: int = 0,
channel_spacing: float = 0.0,
channel_list: list = [],
unit: MeasurementUnit = "strain rate",
) -> None:
"""Create a new blast from NumPy array.
Expand All @@ -91,6 +98,10 @@ def __init__(

self.start_channel = start_channel
self.channel_spacing = channel_spacing
self.channel_list = channel_list

if len(self.channel_list) == 0:
self.channel_list = np.arange(start_channel, len(data), 1)

self.processing_flow = []

Expand All @@ -112,7 +123,7 @@ def n_channels(self) -> int:
@property
def end_channel(self) -> int:
"""End Channel."""
return self.start_channel + self.n_channels
return self.channel_list[-1]

@property
def n_samples(self) -> int:
Expand All @@ -124,18 +135,67 @@ def duration(self) -> float:
"""Duration in seconds."""
return self.n_samples * self.delta_t

def reduce_channels(self, n: int) -> None:
    """Keep only every n-th channel, in-place.

    Channels are taken starting from the first one. ``channel_spacing`` is
    scaled by ``n`` and ``channel_list`` is thinned the same way as the data.

    Args:
        n (int): Step along the channel axis, must be >= 1.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Slice with [::n] instead of [:-1:n]: the latter always drops the last
    # channel, even when it is one of the n-th channels.
    self.data = self.data[::n, :]
    self.channel_spacing = self.channel_spacing * n
    self.channel_list = self.channel_list[::n]

def exlude_channel(self, channel) -> None:
    """Delete a single channel from the blast, in-place.

    Args:
        channel (int): Name of the channel to be removed. If it is not in
            ``channel_list`` a notice is printed and nothing is changed.
    """
    idx = self.get_channel_index(channel, strict=True)
    if idx is None:
        print(f"#{channel} not in list.")
        # In-place operation: return None on every path, as annotated.
        return
    self.data = np.delete(self.data, idx, 0)
    self.channel_list = np.delete(self.channel_list, idx)

def exlude_channels(self, channels) -> None:
    """Remove several channels from the blast, in-place.

    Args:
        channels: Iterable of channel names; each one is handed to
            ``exlude_channel`` in the given order.
    """
    for name in channels:
        self.exlude_channel(channel=name)

def get_channel_name(self, channel_index: int) -> int:
    """Look up the name of a channel by its position in ``channel_list``.

    Args:
        channel_index (int): Index of the channel of interest.

    Returns:
        int: Channel name stored at that index.
    """
    names = self.channel_list
    return names[channel_index]

def get_channel_index(self, channel: int, strict=False) -> int | None:
    """Find the index of a channel, or of the channel closest to it.

    Args:
        channel (int): Channel name.
        strict (bool): If False, fall back to the channel closest to
            ``channel`` when there is no exact match. Defaults to False.

    Returns:
        int | None: Index into ``channel_list``. None if ``strict`` is set
            and ``channel`` is not in the list.
    """
    # NOTE(review): a maintainer commented on this signature: "Be always
    # strict. Or is there advantage in raising an exception only sometimes?"
    # The non-strict fallback is kept here so the interface stays unchanged.

    # asarray: channel_list may have been handed in as a plain list, which
    # does not support the element-wise subtraction below.
    channels = np.asarray(self.channel_list)
    idx = int(np.abs(channels - channel).argmin())
    if channels[idx] == channel:
        return idx
    if not strict:
        print(f"#{channel} not in channel list. #{channels[idx]} is used instead.")
        return idx
    return None

def get_trace(self, channel: int) -> np.ndarray:
"""Get data from a singular channel.

Args:
channel (int): Channel number.
channel (int): Channel name.

Returns:
np.ndarray: 1D Trace.
"""
if not self.start_channel <= channel < self.end_channel:
raise ValueError(f"Channel {channel} is out of bounds")
return self.data[channel - self.start_channel]
return self.data[self.get_channel_index(channel)]

def _time_to_sample(self, time: datetime) -> int:
"""Get sample index for a time.
Expand Down Expand Up @@ -323,13 +383,23 @@ def afk_filter(
normalize_power=normalize_power,
)

def average_traces(self, no_of_traces) -> None:
    """Average over a number of neighbouring traces, in-place.

    A moving average along the channel axis is applied. Only windows that
    fit entirely into the data are kept (``mode="valid"``), so the data
    shrinks by ``no_of_traces - 1`` channels.

    Args:
        no_of_traces (int): Number of channels to be used for averaging.

    Raises:
        ValueError: If ``no_of_traces`` is smaller than 1 or larger than
            the number of channels.
    """
    n_rows = self.data.shape[0]
    if not 1 <= no_of_traces <= n_rows:
        raise ValueError(
            f"no_of_traces must be between 1 and {n_rows}, got {no_of_traces}"
        )
    kernel = np.ones(shape=(no_of_traces, 1)) / no_of_traces
    self.data = signal.fftconvolve(self.data, kernel, mode="valid")

    # Keep channel_list aligned with the shrunken data. Each averaged row is
    # labelled with the channel in the centre of its averaging window.
    offset = (no_of_traces - 1) // 2
    self.channel_list = self.channel_list[offset : offset + self.data.shape[0]]
    # NOTE(review): start_channel is left untouched here - confirm whether
    # it should follow channel_list[0] after averaging.

def follow_phase(
self,
pick_time: datetime,
pick_channel: int,
window_size: int | tuple[int, int] = 50,
threshold: float = 5e-1,
max_shift: int = 20,
template_stacks: int = 1,
) -> tuple[np.ndarray, list[datetime], np.ndarray]:
"""Follow a phase pick through a Blast.

Expand All @@ -340,7 +410,7 @@ def follow_phase(
2. Calculate normalized cross correlate with downwards neighbor.
3. Evaluate maximum x-correlation in allowed window (max_shift).
4. Update template trace and go to 2.

4a. if template_stacks > 1: stack templates for correlation to stabilize
5. Repeat for upward neighbors.

Args:
Expand All @@ -353,6 +423,10 @@ def follow_phase(
Defaults to 5e-1.
max_shift (int, optional): Maximum allowed shift in samples for
neighboring picks. Defaults to 20.
template_stacks (int): Numbers of traces to stack to define the template. Default is 1,
i.e. a single trace.
Stacking close to root template is limited by the distance to the
root template.

Returns:
tuple[np.ndarray, list[datetime], np.ndarray]: Tuple of channel number,
Expand All @@ -362,6 +436,7 @@ def follow_phase(
window_size = (window_size, window_size)

pick_channel -= self.start_channel

root_idx = self._time_to_sample(pick_time)

# Ensure the window is odd-sized with the pick in center.
Expand All @@ -372,15 +447,19 @@ def follow_phase(

pick_channels, pick_times, pick_correlations = [], [], []

def prepare_template(data: np.ndarray) -> np.ndarray:
def prepare_template(data: Deque[np.ndarray]) -> np.ndarray:
data = np.mean(data, axis=0)
return data * template_taper

def correlate(data: np.ndarray, direction: Literal[1, -1] = 1) -> None:
template = root_template.copy()
index = root_idx
template_stack: Deque[np.ndarray] = deque(
[np.array(template)], maxlen=template_stacks
)

index = root_idx
for ichannel, trace in enumerate(data):
template = prepare_template(template)
template = prepare_template(template_stack)
norm = np.sqrt(np.sum(template**2)) * np.sqrt(np.sum(trace**2))
correlation = np.correlate(trace, template, mode="same")
correlation = np.abs(correlation / norm)
Expand Down Expand Up @@ -409,15 +488,21 @@ def correlate(data: np.ndarray, direction: Literal[1, -1] = 1) -> None:
template = trace[
phase_idx - window_size[0] : phase_idx + window_size[1] + 1
].copy()

# stacking
template_stack.append(template)
index = phase_idx

correlate(self.data[pick_channel:])
correlate(self.data[: pick_channel - 1][::-1], direction=-1)

pick_channels = np.array(pick_channels) + self.start_channel
pick_correlations = np.array(pick_correlations)

return pick_channels, pick_times, pick_correlations
return Picks(
channel=pick_channels.tolist(),
time=pick_times,
correlation=pick_correlations,
)

def taper(self, alpha: float = 0.05) -> None:
"""Taper in time-domain and in-place with a Tukey window.
Expand Down Expand Up @@ -459,7 +544,10 @@ def trim_channels(self, begin: int = 0, end: int = -1) -> Blast:
Blast: Trimmed Blast.
"""
blast = self.copy()
begin = blast.get_channel_index(begin, strict=False)
end = blast.get_channel_index(end, strict=False)
blast.start_channel += begin
blast.channel_list = blast.channel_list[begin:end]
blast.data = blast.data[begin:end]
return blast

Expand All @@ -486,6 +574,34 @@ def trim_time(self, begin: float = 0.0, end: float = -1.0) -> Blast:
blast.start_time += timedelta(seconds=begin)
return blast

def trim_from_picks(self, picks: Picks, time_window: int = 1) -> list:
    """Cut a time window following each pick out of the picked channel.

    The blast is copied and converted with ``as_traces()``. For every pick
    the trace whose ``station`` matches the pick's channel is chopped from
    the pick time to pick time plus ``time_window``.

    Args:
        picks (Picks): Picks providing matching ``channel`` and ``time``
            sequences.
        time_window (int): Length of the window after the pick time; it is
            added to the pick's POSIX timestamp. Defaults to 1.

    Returns:
        list: Chopped traces, one per pick that has a matching channel and
            whose time lies within the time span of that trace.
    """
    traces = self.copy().as_traces()

    trimmed_traces = []
    for channel, pick_time in zip(picks.channel, picks.time):
        timestamp = pick_time.timestamp()
        # Find the trace belonging to the picked channel.
        tr = next((x for x in traces if int(x.station) == channel), None)
        if tr is None:
            # Pick refers to a channel that is not part of this blast.
            continue

        # Skip picks outside of the time range covered by the trace.
        if not tr.time_span[0] <= timestamp <= tr.time_span[1]:
            continue

        trimmed_traces.append(tr.chop(tmin=timestamp, tmax=timestamp + time_window))

    return trimmed_traces

def to_strain(self, detrend: bool = True) -> Blast:
"""Convert the traces to strain.

Expand Down Expand Up @@ -590,7 +706,7 @@ def plot(
dates.date2num(self.start_time) if show_date else 0.0,
)

data = self.data.copy()
data = self.data.copy().astype(float)
if normalize_traces:
data /= np.abs(data.max(axis=1, keepdims=True))

Expand Down Expand Up @@ -653,6 +769,20 @@ def as_traces(self) -> list[Trace]:
)
return traces

def to_obspy_stream(self):
    """Convert the blast to an ObsPy stream.

    The traces from ``as_traces()`` are added to a Pyrocko pile, which is
    then passed to ``obspy_compat.to_obspy_stream``.

    Returns:
        The result of ``obspy_compat.to_obspy_stream`` for that pile
        (presumably an ObsPy ``Stream`` - confirm against Pyrocko docs).
    """
    trace_pile = pile.Pile()
    traces = self.as_traces()
    trace_pile.add(traces)
    stream = obspy_compat.to_obspy_stream(trace_pile)
    return stream

def snuffle(self, **kwargs) -> None:
    """Open a snuffler window showing the traces of this blast.

    Args:
        **kwargs: Passed through unchanged to ``trace.snuffle``.
    """
    traces = self.as_traces()
    trace.snuffle(traces, **kwargs)

@classmethod
def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blast:
"""Create Blast from a list of Pyrocko traces.
Expand All @@ -671,7 +801,8 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
if not traces:
raise ValueError("Empty list of traces")

traces = sorted(traces, key=lambda tr: int(tr.station))
traces = sorted(traces, key=lambda tr: int(re.sub(r"\D", "", tr.station)))
channel_list = np.array([int(re.sub(r"\D", "", tr.station)) for tr in traces])
ntraces = len(traces)

tmin = set()
Expand Down Expand Up @@ -702,8 +833,10 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
data=data,
start_time=datetime.fromtimestamp(tmin.pop(), tz=timezone.utc),
sampling_rate=int(1.0 / delta_t.pop()),
start_channel=min(int(tr.station) for tr in traces),
# start_channel=min(int(re.sub(r"\D", "", tr.station)) for tr in traces),
start_channel=channel_list[0],
channel_spacing=channel_spacing,
channel_list=channel_list,
)

@classmethod
Expand All @@ -721,7 +854,10 @@ def from_miniseed(cls, file: PathStr, channel_spacing: float = 4.0) -> Blast:
from pyrocko import io

traces = io.load(str(file), format="mseed")
return cls.from_pyrocko(traces, channel_spacing=channel_spacing)
return cls.from_pyrocko(
traces,
channel_spacing=channel_spacing,
)


TFun = TypeVar("TFun", bound=Callable[..., Any])
Expand Down Expand Up @@ -791,11 +927,6 @@ def __len__(self) -> int:

mute_median = shared_function(Blast.mute_median)
one_bit_normalization = shared_function(Blast.one_bit_normalization)
afk_filter = shared_function(Blast.afk_filter)
decimate = shared_function(Blast.decimate)

trim_time = shared_function(Blast.trim_time)
trim_channels = shared_function(Blast.trim_channels)

to_strain = shared_function(Blast.to_strain)
to_relative_velocity = shared_function(Blast.to_relative_velocity)
Expand Down
Loading