Model prediction
This notebook downloads an example baseline model and uses it to make predictions on an example soundscape. All baseline models are available here, although they are beta models and are not recommended for research use.
[1]:
# Suppress warnings
import warnings
warnings.simplefilter('ignore')
# Import modules from OpenSoundscape
from opensoundscape.torch.predict import predict
from opensoundscape.datasets import SingleTargetAudioDataset, SplitterDataset
from opensoundscape.helpers import run_command
[2]:
import torch
import torch.nn
import torchvision.models
import torch.utils.data
[3]:
import pandas as pd
from pathlib import Path
Prepare model
Download model
Download the example model for Wood Thrush, Hylocichla mustelina.
[4]:
def download_from_box(link, name):
    # Download `link` to the local path `name` with curl (-L follows redirects)
    run_command(f"curl -L {link} -o ./{name}")
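The helper above shells out to curl, so it only works where curl is on the PATH. A minimal standard-library alternative is sketched below; the name `download_file` is hypothetical and not part of OpenSoundscape. It uses `urllib.request.urlretrieve`, whose default opener follows HTTP redirects much like `curl -L`. The last lines exercise it offline with a local `file://` URL.

```python
import tempfile
from pathlib import Path
from urllib.request import urlretrieve


def download_file(link, destination):
    """Hypothetical helper: download `link` to `destination`, creating parent folders."""
    destination = Path(destination)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # urlretrieve follows HTTP redirects via urllib's default opener
    urlretrieve(link, destination)
    return destination


# Offline usage check: "download" a local file through a file:// URL
with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp, "example.txt")
    source.write_text("hello")
    saved = download_file(source.as_uri(), Path(tmp, "prediction_example", "copy.txt"))
    downloaded_text = saved.read_text()
```

With network access, `download_file(link, model_filename)` would play the same role as `download_from_box` in the cells below.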
[5]:
folder_name = "prediction_example"
folder_path = Path(folder_name)
folder_path.mkdir(exist_ok=True)
model_filename = folder_path.joinpath("hylocichla-mustelina-epoch-4.model")
download_from_box(
    link="https://pitt.box.com/shared/static/dslgslmag7y8ojqxv28mwhbnt7irpgeo.model",
    name=model_filename,
)
Load model
The model must be loaded with the same architecture it was created with: a resnet18 convolutional neural network whose final fully connected layer is replaced by a Linear classifier. This model predicts two “classes”: the presence and the absence of Wood Thrush.
[6]:
num_classes = 2
# Rebuild the architecture: resnet18 with a 2-output Linear classifier head
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# Load the downloaded weights; map_location lets a CPU-only machine load them
model.load_state_dict(torch.load(model_filename, map_location="cpu"))
[6]:
<All keys matched successfully>
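The `<All keys matched successfully>` line is what PyTorch's `load_state_dict` returns when every saved parameter name and shape lines up with the rebuilt model. The sketch below illustrates why the architecture must match, using only `torch.nn`: a hypothetical 512-input `Linear` head stands in for the full network (resnet18's final layer does take 512 input features). Weights saved from a 2-class head load cleanly into another 2-class head, but loading them into a 3-class head raises a `RuntimeError` because the tensor shapes differ.

```python
import io

import torch
import torch.nn


def make_head(num_classes):
    # Hypothetical stand-in for the replaced classifier head (512 -> num_classes)
    return torch.nn.Linear(512, num_classes)


# Save the weights of a 2-class head into an in-memory buffer
buffer = io.BytesIO()
torch.save(make_head(2).state_dict(), buffer)

# Loading into an identically shaped head succeeds
buffer.seek(0)
result = make_head(2).load_state_dict(torch.load(buffer))
print(result)

# Loading into a head with a different number of classes fails
buffer.seek(0)
try:
    make_head(3).load_state_dict(torch.load(buffer))
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print("Load failed:", str(err).splitlines()[0])
```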
Prepare prediction files
Download an example soundscape which contains Wood Thrush vocalizations.
Download data
[7]:
data_filename = folder_path.joinpath("1min.wav")
download_from_box(
    link="https://pitt.box.com/shared/static/z73eked7quh1t2pp93axzrrpq6wwydx0.wav",
    name=data_filename,
)
Split data
The example soundscape must be split into clips of the same duration as those the model was trained on; for this model, that is 5 s.
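To make the idea of splitting concrete, here is a standard-library sketch of cutting a WAV into fixed-length clips with the `wave` module. It is not how `SplitterDataset` is implemented: the function names are hypothetical, and treating `include_last_segment=True` as "keep the final, shorter clip" is an assumption that may differ from OpenSoundscape's behavior. With `duration_s=5` and no overlap, a 60 s recording yields 12 clips.

```python
import io
import wave


def make_silent_wav(seconds, rate=1000):
    """Hypothetical test input: mono 16-bit silence, returned as WAV bytes."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(rate)
        w.writeframes(b"\x00\x00" * (seconds * rate))
    return buf.getvalue()


def split_wav_bytes(wav_bytes, duration_s, overlap_s=0.0, include_last_segment=True):
    """Split WAV bytes into clips; returns a list of (begin_s, end_s, clip_bytes)."""
    if not 0 <= overlap_s < duration_s:
        raise ValueError("overlap must be non-negative and shorter than the clip")
    with wave.open(io.BytesIO(wav_bytes), "rb") as src:
        params = src.getparams()
        rate = src.getframerate()
        frames = src.readframes(src.getnframes())
    frame_size = params.sampwidth * params.nchannels
    total = len(frames) // frame_size
    clip_len = int(duration_s * rate)
    hop = int((duration_s - overlap_s) * rate)

    clips = []
    start = 0
    while start < total:
        end = min(start + clip_len, total)
        # A clip shorter than duration_s can only be the last one
        if end - start < clip_len and not include_last_segment:
            break
        out = io.BytesIO()
        with wave.open(out, "wb") as dst:
            dst.setparams(params)  # the header's frame count is patched on write
            dst.writeframes(frames[start * frame_size:end * frame_size])
        clips.append((start / rate, end / rate, out.getvalue()))
        if end == total:
            break
        start += hop
    return clips


clips = split_wav_bytes(make_silent_wav(12), duration_s=5)
bounds = [(begin, end) for begin, end, _ in clips]
```

Here `bounds` holds the begin and end times of each clip; the 12 s example produces two full 5 s clips plus a 2 s remainder.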
[8]:
files_to_split = [data_filename]
split_directory = folder_path.joinpath("split_files")
split_directory.mkdir(exist_ok=True)
# Split each file into non-overlapping 5 s clips, keeping the final segment
dataset = SplitterDataset(
    files_to_split,
    overlap=0,
    duration=5,
    output_directory=split_directory,
    include_last_segment=True,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=SplitterDataset.collate_fn,
)
# Iterating over the dataloader writes the clips to split_directory;
# each item is a CSV row describing one clip
results_csv = folder_path.joinpath("prediction_files.csv")
with open(results_csv, "w") as f:
    # No annotations are used, so there are no Annotations/Labels columns
    f.write("Source,Begin (s),End (s),Destination\n")
    for data in dataloader:
        for output in data:
            f.write(f"{output}\n")
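The split clips get hash-like file names (as the prediction table later shows), so the file names alone do not reveal where each clip falls in the recording. The CSV written above records each clip's Destination alongside its begin and end times, which is enough to restore temporal order. A minimal sketch with the standard `csv` module; the CSV rows and file names below are made up for illustration, and the helper name `clip_start_times` is hypothetical.

```python
import csv
import io

# Hypothetical contents in the format written above
csv_text = """Source,Begin (s),End (s),Destination
1min.wav,5.0,10.0,split_files/bbb.wav
1min.wav,0.0,5.0,split_files/aaa.wav
1min.wav,10.0,15.0,split_files/ccc.wav
"""


def clip_start_times(csv_file):
    """Map each clip's Destination path to its start time in seconds."""
    reader = csv.DictReader(csv_file)
    return {row["Destination"]: float(row["Begin (s)"]) for row in reader}


starts = clip_start_times(io.StringIO(csv_text))
# Clip paths sorted by where they begin in the source recording
ordered = sorted(starts, key=starts.get)
```

On the real file, `clip_start_times(open(results_csv))` would give the same mapping for the clips written to `split_directory`.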
Create a Dataset
Create a dataset from the split clips. A dictionary maps numeric labels to class names: 1 corresponds to Wood Thrush presence and 0 to Wood Thrush absence.
[9]:
# Collect the split clips into a DataFrame with one file path per row
files_to_analyze = list(split_directory.glob("*.wav"))
sample_df = pd.DataFrame(columns=["file"], data=files_to_analyze)
[10]:
label_dict = {0: "absent", 1: "hylocichla-mustelina"}
test_dataset = SingleTargetAudioDataset(
    sample_df,
    filename_column="file",
    label_dict=label_dict,
)
Use model on prediction files
[11]:
model.eval()
prediction_df = predict(model, test_dataset, label_dict=label_dict)
prediction_df
[11]:
file | absent | hylocichla-mustelina
---|---|---
prediction_example/split_files/bc645003351149f4a7e2c7109b22afc1.wav | 0.816133 | -0.903320
prediction_example/split_files/e36a0f200cdf42a23d49e78445121387.wav | 1.480433 | -0.927409
prediction_example/split_files/4940c91a1837410240042cf55ccad568.wav | 1.940377 | -1.725088
prediction_example/split_files/cfc05bd9e1b97eebdca3badc288de0cd.wav | 2.629047 | -1.988923
prediction_example/split_files/32747f95e81ee34c56ed177c4f7e7df5.wav | 2.513747 | -2.366485
prediction_example/split_files/369134205221b5a25fac0e264d0a1482.wav | 2.351259 | -1.628652
prediction_example/split_files/f3d6aeabe7725f649dc56d6db04aa83f.wav | 1.570931 | -1.124706
prediction_example/split_files/54534197c0768b6bb2a9305013e8c1af.wav | 1.744635 | -1.055664
prediction_example/split_files/e0c2d4aed1d79d4a6194be948d3292da.wav | 1.315882 | -1.407135
prediction_example/split_files/9d276a5dd54b631c4aa63da407a1225d.wav | 1.766514 | -1.096341
prediction_example/split_files/636f23557581b700f286b7db29d01b61.wav | 0.273381 | -0.397208
prediction_example/split_files/e55ba1b5a1316fcda5f3b4d73b2e36ee.wav | 2.138355 | -1.632506
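The scores in this table are not probabilities: they can be negative and do not sum to 1 across a row, so they look like raw model outputs (logits). Assuming the model was trained with a softmax-based loss (the notebook does not say), a softmax over each row is a common way to turn them into class probabilities, and the highest-scoring index can be mapped back to a class name through `label_dict`. A minimal sketch using two rows copied from the table above:

```python
import math

label_dict = {0: "absent", 1: "hylocichla-mustelina"}


def softmax(scores):
    # Subtract the max score before exponentiating, for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Raw scores from the table, in label_dict order: [absent, hylocichla-mustelina]
rows = {
    "636f23557581b700f286b7db29d01b61.wav": [0.273381, -0.397208],
    "cfc05bd9e1b97eebdca3badc288de0cd.wav": [2.629047, -1.988923],
}

results = {}
for name, scores in rows.items():
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    results[name] = (label_dict[best], probs[1])
    print(name, label_dict[best], round(probs[1], 3))
```

With two classes, the softmax probability of presence depends only on the difference between the two scores, so a clip's ranking is unchanged by the conversion.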
This command “cleans up” by deleting all the downloaded files and results.
[12]:
import shutil
shutil.rmtree(folder_path)