Skip to content

Commit

Permalink
apply upsampling before the network
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirMardan committed Jan 24, 2025
1 parent faed7f1 commit be953be
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ preprocessed/
picked_fb_test_data/
example/checkpoints/

*.egg-info
*.egg-info
0_fbp_env/
build/
5 changes: 2 additions & 3 deletions example/0_ex_train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#%% ========== Loading required packages
import pandas as pd
import sys
import os.path as osp
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
import shutil
from matplotlib.ticker import MaxNLocator

sys.path.append(osp.abspath(osp.join(__file__, "../../")))
# import sys
# sys.path.append(osp.abspath(osp.join(__file__, "../../")))
from first_break_picking import train
from first_break_picking.tools import seed_everything
from first_break_picking.data import save_shots_fb
Expand Down
4 changes: 2 additions & 2 deletions example/1_ex_predict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#%% ========== Loading required packages
import sys
import os.path as osp
from pathlib import Path
import shutil

sys.path.append(osp.abspath(osp.join(__file__, "../../")))
# import sys
# sys.path.append(osp.abspath(osp.join(__file__, "../../")))
from first_break_picking.data import save_shots_fb
from first_break_picking import predict

Expand Down
2 changes: 1 addition & 1 deletion example/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#%% ============ Initiate ===============
num_epcohs = 5
num_epcohs = 20

batch_size = 15
split_nt = 22
Expand Down
12 changes: 8 additions & 4 deletions first_break_picking/train_eval/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def predict(base_dir: str,
with torch.no_grad():
if validation:
true_masks = []
for shot_number, (batch, true_mask, fbt_file_name ) in enumerate(loop):
for shot_number, (batch, true_mask, fbt_file_name) in enumerate(loop):
fbt_file_name = fbt_file_name[0]

shot1, predicted_pick, predicted_segment, true_mask1 = predict_validation(
Expand All @@ -358,7 +358,7 @@ def predict(base_dir: str,
overlap=overlap,
shot_id=fbt_file_name,
smoothing_threshold=smoothing_threshold,
upsampler=upsampler,
# upsampler=upsampler,
data_info=data_info,
case_specific_parameters=case_specific_parameters
)
Expand All @@ -373,14 +373,18 @@ def predict(base_dir: str,
for shot_number, (batch, fbt_file_name) in enumerate(loop):
fbt_file_name = fbt_file_name[0]

# nsp: data.shape=[1, 3, 1, 512, 22]
batch, _ = upsampler(batch.squeeze(0), batch.squeeze(0))
#nsp: data.shape= [3, 1, 512, 512])

shot, predicted_pick, predicted_segment = predict_test(
batch=batch,
batch=batch.unsqueeze(0),
model=model,
split_nt=split_nt,
overlap=overlap,
shot_id=fbt_file_name,
smoothing_threshold=smoothing_threshold,
upsampler=upsampler,
# upsampler=upsampler,
data_info=data_info,
case_specific_parameters=case_specific_parameters
)
Expand Down

0 comments on commit be953be

Please sign in to comment.