Commit
refactor: for huggingface
xhiroga committed Jan 10, 2024
1 parent a613cda commit 71601f7
Showing 8 changed files with 12,266 additions and 12,108 deletions.
@@ -1,3 +1,5 @@
data/
figures/
models/
models/*

__pycache__
@@ -5,3 +5,12 @@ conda create -f environment.yml
conda activate chiikawa-yonezu
pip install fugashi ipadic
```

## Run gradio

```powershell
conda activate chiikawa-yonezu
python app.py
# or
conda run -n chiikawa-yonezu python app.py # not recommended because standard output is not displayed
```
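
As an aside, once `python app.py` is running, the Gradio demo can also be queried programmatically. A minimal sketch using `gradio_client` (added in the pip dependencies below), assuming Gradio's default local URL and default endpoint name:

```python
# Hypothetical usage sketch: the URL, endpoint name, and example text are
# assumptions, not part of this commit.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # Gradio's default local address
result = client.predict(
    "晴れた空に種を蒔こう",  # text passed to classify_text
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(result)  # label data returned by the "label" output component
```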
76 changes: 76 additions & 0 deletions computer-science/machine-learning/_src/chiikawa-yonezu/app.py
@@ -0,0 +1,76 @@
from pprint import pprint

import gradio as gr
import torch

from safetensors import safe_open
from transformers import BertTokenizer

from utils.ClassifierModel import ClassifierModel


def _classify_text(text, model, device, tokenizer, max_length=20):
    """
    Return the probabilities that the text belongs to 'ちいかわ' or '米津玄師'.
    """

    # Tokenize the text and convert it to PyTorch tensors
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pprint(f"inputs: {inputs}")

    # Run inference with the model
    model.eval()
    with torch.no_grad():
        outputs = model(
            inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
        )
    pprint(f"outputs: {outputs}")
    probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Extract the per-class probabilities
    chiikawa_prob = probabilities[0][0].item()
    yonezu_prob = probabilities[0][1].item()

    return chiikawa_prob, yonezu_prob


def classify_text(text):
    is_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if is_cuda else "cpu")
    pprint(f"device: {device}")

    model_save_path = "models/model.safetensors"
    tensors = {}
    with safe_open(model_save_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)

    inference_model: torch.nn.Module = ClassifierModel().to(device)
    inference_model.load_state_dict(tensors)

    tokenizer = BertTokenizer.from_pretrained(
        "cl-tohoku/bert-base-japanese-whole-word-masking"
    )
    chii_prob, yone_prob = _classify_text(text, inference_model, device, tokenizer)
    return {"ちいかわ": chii_prob, "米津玄師": yone_prob}


demo = gr.Interface(
    fn=classify_text,
    inputs="textbox",
    outputs="label",
    examples=[
        "炊き立て・・・・ってコト!?",
        "晴れた空に種を蒔こう",
    ],
)

demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀
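
The `ClassifierModel` imported from `utils.ClassifierModel` is not shown in this diff. A minimal sketch of what such a module might look like, assuming a BERT encoder (`cl-tohoku/bert-base-japanese-whole-word-masking`) with a two-class linear head whose `forward(input_ids, attention_mask)` returns logits of shape `[batch, 2]`, matching the `softmax(dim=1)` call above:

```python
# Hypothetical sketch of utils/ClassifierModel.py -- the real module is not
# part of this excerpt, so names and layer sizes here are assumptions.
import torch
from transformers import BertModel


class ClassifierModel(torch.nn.Module):
    def __init__(self, pretrained="cl-tohoku/bert-base-japanese-whole-word-masking"):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained)
        # Two classes: index 0 = ちいかわ, index 1 = 米津玄師 (as read in app.py)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the pooled [CLS] representation
        return self.classifier(outputs.pooler_output)
```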
@@ -76,6 +76,7 @@ dependencies:
- jaconv=0.3.4=pyhd8ed1ab_0
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.2=pyhd8ed1ab_1
- joblib=1.3.2=pyhd8ed1ab_0
- jupyter_client=8.6.0=pyhd8ed1ab_0
- jupyter_core=5.7.0=py311h1ea47a8_0
- kiwisolver=1.4.5=py311h005e61a_1
@@ -206,13 +207,16 @@ dependencies:
- regex=2023.12.25=py311ha68e1ae_0
- requests=2.31.0=pyhd8ed1ab_0
- safetensors=0.3.3=py311hc37eb10_1
- scikit-learn=1.3.2=py311h142b183_2
- scipy=1.11.4=py311h0b4df5a_0
- setuptools=69.0.3=pyhd8ed1ab_0
- sip=6.7.12=py311h12c1d0e_0
- six=1.16.0=pyh6c4a22f_0
- snappy=1.1.10=hfb803bf_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.12=pyh04b8f61_3
- tbb=2021.11.0=h91493d7_0
- threadpoolctl=3.2.0=pyha21a80b_0
- tk=8.6.13=h5226925_1
- tokenizers=0.15.0=py311h91c4a10_1
- toml=0.10.2=pyhd8ed1ab_0
@@ -242,5 +246,40 @@ dependencies:
- zipp=3.17.0=pyhd8ed1ab_0
- zstd=1.5.5=h12be248_0
- pip:
- aiofiles==23.2.1
- altair==5.2.0
- annotated-types==0.6.0
- anyio==4.2.0
- click==8.1.7
- fastapi==0.108.0
- ffmpy==0.3.1
- gradio==4.13.0
- gradio-client==0.8.0
- h11==0.14.0
- httpcore==1.0.2
- httpx==0.26.0
- importlib-resources==6.1.1
- jsonschema==4.20.0
- jsonschema-specifications==2023.12.1
- markdown-it-py==3.0.0
- mdurl==0.1.2
- orjson==3.9.10
- pydantic==2.5.3
- pydantic-core==2.14.6
- pydub==0.25.1
- python-multipart==0.0.6
- referencing==0.32.1
- rich==13.7.0
- rpds-py==0.16.2
- semantic-version==2.10.0
- shellingham==1.5.4
- sniffio==1.3.0
- starlette==0.32.0.post1
- tomlkit==0.12.0
- toolz==0.12.0
- torchaudio==2.1.2
- torchvision==0.16.2
- typer==0.9.0
- uvicorn==0.25.0
- websockets==11.0.3

@@ -8,7 +8,7 @@
"source": [
"from transformers import BertTokenizer, BertModel\n",
"import torch\n",
"from pprint import pprint\n"
"from pprint import pprint"
]
},
{
