This repository has been archived by the owner on Feb 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Also optimise Dockerfile for quick rebuilds. Also add multiple input files support to `predict.py`.
- Loading branch information
Showing
5 changed files
with
130 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
FROM conda/miniconda3 | ||
|
||
RUN apt-get update && \ | ||
apt-get install -y libsndfile1 | ||
RUN apt update && apt install -y g++ | ||
|
||
# Copy requirements.txt and run pip first so that changes to the application | ||
# code do not require a rebuild of the entire image | ||
COPY requirements.txt /app/ | ||
RUN conda update conda && \ | ||
conda install "keras<2.4" "numpy<2" "scikit-learn<0.23" && \ | ||
conda install -c conda-forge librosa theano | ||
|
||
ADD . /app | ||
WORKDIR /app | ||
|
||
VOLUME /data | ||
|
||
RUN pip install --upgrade pip && \ | ||
pip install -r requirements.txt | ||
ENV KERAS_BACKEND=theano |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import base64 | ||
import json | ||
import numpy as np | ||
from werkzeug.wrappers import Request, Response | ||
import predict | ||
|
||
|
||
def decode_audio(audio_bytes): | ||
return np.frombuffer(base64.b64decode(audio_bytes), dtype="float32") | ||
|
||
|
||
def make_app(estimate_func): | ||
def app(environ, start_response): | ||
inputs = json.loads(Request(environ).get_data()) | ||
|
||
outputs = [] | ||
for inp in inputs: | ||
try: | ||
est = int(estimate_func(decode_audio(inp))) | ||
except Exception as e: | ||
print(f"Error estimating speaker count for input {len(outputs)}: {e}") | ||
est = None | ||
outputs.append(est) | ||
|
||
return Response(json.dumps(outputs))(environ, start_response) | ||
|
||
return app | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
import functools | ||
from werkzeug.serving import run_simple | ||
|
||
parser = argparse.ArgumentParser( | ||
description="Run simple JSON api server to predict speaker count" | ||
) | ||
parser.add_argument("--model", default="CRNN", help="model name") | ||
args = parser.parse_args() | ||
|
||
model = predict.load_model(args.model) | ||
scaler = predict.load_scaler() | ||
|
||
app = make_app(functools.partial(predict.count, model=model, scaler=scaler)) | ||
run_simple("0.0.0.0", 5000, app, use_debugger=True) |
This file was deleted.
Oops, something went wrong.