diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 664f655a..e8ec0a61 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,7 +16,7 @@ jobs: pytest: strategy: matrix: - os: [windows-latest, macos-latest, ubuntu-latest] + os: [macos-latest, ubuntu-latest] python-version: ["3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/README.md b/README.md index 011a7a2b..389543b8 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ______________________________________________________________________ [![Documentation](https://img.shields.io/badge/docs-passing-green)](https://alexandrainst.github.io/coral_models/coral_models.html) [![License](https://img.shields.io/github/license/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/blob/main/LICENSE) [![LastCommit](https://img.shields.io/github/last-commit/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/commits/main) -[![Code Coverage](https://img.shields.io/badge/Coverage-53%25-orange.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests) +[![Code Coverage](https://img.shields.io/badge/Coverage-54%25-orange.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests) Developers: diff --git a/config/config.yaml b/config/config.yaml index 332cb151..ec8017b8 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -19,6 +19,10 @@ characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü ' max_seconds_per_example: 10 dataloader_num_workers: 8 +# Can be `longest`, `max_length` or `do_not_pad` +# NOTE: This is automatically set to `max_length` in a multi-gpu setting +padding: longest + # This is a list of the sampling probability of each dataset, where null means that # each dataset will be sampled equally often dataset_probabilities: @@ -46,8 +50,8 @@ save_total_limit: 2 learning_rate: 3e-5 adam_first_momentum: 0.9 adam_second_momentum: 0.98 -batch_size: 8 -gradient_accumulation: 32 +total_batch_size: 
256 +per_device_batch_size: 16 max_steps: 50_000 warmup_steps: 1_000 logging_steps: 10 diff --git a/config/model/test_wav2vec2.yaml b/config/model/test_wav2vec2.yaml index 5aa6e062..44351548 100644 --- a/config/model/test_wav2vec2.yaml +++ b/config/model/test_wav2vec2.yaml @@ -17,7 +17,7 @@ mask_time_prob: 0.075 mask_time_length: 10 mask_feature_prob: 0.075 mask_feature_length: 10 -layerdrop: 0.1 +layerdrop: 0.0 # NOTE: This parameter cannot be used in a multi-gpu setting! ctc_loss_reduction: sum # Decoder hyperparameters diff --git a/config/model/wav2vec2.yaml b/config/model/wav2vec2.yaml index 5d07214b..768954ca 100644 --- a/config/model/wav2vec2.yaml +++ b/config/model/wav2vec2.yaml @@ -14,11 +14,11 @@ hidden_dropout: 0.0 feat_proj_dropout: 0.0 feat_quantizer_dropout: 0.0 final_dropout: 0.0 -mask_time_prob: 0.3 +mask_time_prob: 0.5 mask_time_length: 10 -mask_feature_prob: 0.3 +mask_feature_prob: 0.5 mask_feature_length: 64 -layerdrop: 0.1 +layerdrop: 0.1 # This will automatically be set to 0 in a multi-gpu setting ctc_loss_reduction: mean # Decoder hyperparameters diff --git a/poetry.lock b/poetry.lock index 7d372a96..c1b902d0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -151,6 +151,17 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -722,6 +733,41 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepspeed" +version = "0.12.3" +description = 
"DeepSpeed library" +optional = false +python-versions = "*" +files = [ + {file = "deepspeed-0.12.3.tar.gz", hash = "sha256:dc8a0c261589856743c3b3e7bf9829eded2cc8b2464a40456c3a997ed3a01a08"}, +] + +[package.dependencies] +hjson = "*" +ninja = "*" +numpy = "*" +packaging = ">=20.0" +psutil = "*" +py-cpuinfo = "*" +pydantic = "*" +pynvml = "*" +torch = "*" +tqdm = "*" + +[package.extras] +1bit-mpi = ["mpi4py"] +all = ["accelerate", "autodoc_pydantic", "clang-format (==16.0.2)", "coverage", "deepspeed-kernels", "diffusers", "docutils (<0.18)", "future", "google", "hjson", "importlib-metadata (>=4)", "lm-eval (==0.3.0)", "mpi4py", "mup", "neural-compressor (==2.1.0)", "packaging", "pre-commit (>=2.20.0)", "protobuf", "psutil", "py-cpuinfo", "pydantic (<2.0.0)", "pytest", "pytest-forked", "pytest-randomly", "pytest-xdist", "recommonmark", "sphinx", "sphinx-rtd-theme", "sphinx_rtd_theme", "tabulate", "tensorboard", "torch", "torchvision", "tqdm", "transformers", "transformers[sentencepiece]", "triton", "triton (==1.0.0)", "triton (>=2.1.0)", "wandb", "xgboost"] +autotuning = ["tabulate"] +autotuning-ml = ["hjson", "tabulate", "xgboost"] +dev = ["accelerate", "clang-format (==16.0.2)", "coverage", "deepspeed-kernels", "docutils (<0.18)", "future", "importlib-metadata (>=4)", "mup", "pre-commit (>=2.20.0)", "pytest", "pytest-forked", "pytest-randomly", "pytest-xdist", "recommonmark", "sphinx", "sphinx-rtd-theme", "tensorboard", "torchvision", "transformers", "wandb"] +inf = ["google", "lm-eval (==0.3.0)", "protobuf", "transformers", "transformers[sentencepiece]"] +readthedocs = ["autodoc_pydantic", "docutils (<0.18)", "hjson", "packaging", "psutil", "py-cpuinfo", "pydantic (<2.0.0)", "recommonmark", "sphinx_rtd_theme", "torch", "tqdm"] +sd = ["diffusers", "triton"] +sparse = ["neural-compressor (==2.1.0)"] +sparse-attn = ["triton (==1.0.0)"] +triton = ["triton (>=2.1.0)"] + [[package]] name = "dill" version = "0.3.7" @@ -1029,6 +1075,17 @@ gitdb = ">=4.0.1,<5" 
[package.extras] test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-sugar"] +[[package]] +name = "hjson" +version = "3.1.0" +description = "Hjson, a user interface for JSON." +optional = false +python-versions = "*" +files = [ + {file = "hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89"}, + {file = "hjson-3.1.0.tar.gz", hash = "sha256:55af475a27cf83a7969c808399d7bccdec8fb836a07ddbd574587593b9cdcf75"}, +] + [[package]] name = "huggingface-hub" version = "0.17.3" @@ -1808,6 +1865,33 @@ doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx- extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "ninja" +version = "1.11.1.1" +description = "Ninja is a small build system with a focus on speed" +optional = false +python-versions = "*" +files = [ + {file = "ninja-1.11.1.1-py2.py3-none-macosx_10_9_universal2.macosx_10_9_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:376889c76d87b95b5719fdd61dd7db193aa7fd4432e5d52d2e44e4c497bdbbee"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_i686.manylinux_2_5_i686.whl", hash = "sha256:ecf80cf5afd09f14dcceff28cb3f11dc90fb97c999c89307aea435889cb66877"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:84502ec98f02a037a169c4b0d5d86075eaf6afc55e1879003d6cab51ced2ea4b"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73b93c14046447c7c5cc892433d4fae65d6364bec6685411cb97a8bcf815f93a"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:18302d96a5467ea98b68e1cae1ae4b4fb2b2a56a82b955193c637557c7273dbd"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = 
"sha256:aad34a70ef15b12519946c5633344bc775a7656d789d9ed5fdb0d456383716ef"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d491fc8d89cdcb416107c349ad1e3a735d4c4af5e1cb8f5f727baca6350fdaea"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:7563ce1d9fe6ed5af0b8dd9ab4a214bf4ff1f2f6fd6dc29f480981f0f8b8b249"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:9df724344202b83018abb45cb1efc22efd337a1496514e7e6b3b59655be85205"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:3e0f9be5bb20d74d58c66cc1c414c3e6aeb45c35b0d0e41e8d739c2c0d57784f"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:76482ba746a2618eecf89d5253c0d1e4f1da1270d41e9f54dfbd91831b0f6885"}, + {file = "ninja-1.11.1.1-py2.py3-none-win32.whl", hash = "sha256:fa2ba9d74acfdfbfbcf06fad1b8282de8a7a8c481d9dee45c859a8c93fcc1082"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:95da904130bfa02ea74ff9c0116b4ad266174fafb1c707aa50212bc7859aebf1"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:185e0641bde601e53841525c4196278e9aaf4463758da6dd1e752c0a0f54136a"}, + {file = "ninja-1.11.1.1.tar.gz", hash = "sha256:9d793b08dd857e38d0b6ffe9e6b7145d7c485a42dcfea04905ca0cdb6017cc3c"}, +] + +[package.extras] +test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "pytest (>=4.5.0)", "pytest-cov (>=2.7.1)", "pytest-runner (>=5.1)", "pytest-virtualenv (>=1.7.0)", "virtualenv (>=15.0.3)"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2422,6 +2506,17 @@ files = [ [package.extras] test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + 
{file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, +] + [[package]] name = "pyarrow" version = "14.0.1" @@ -2513,6 +2608,142 @@ pygtrie = ">=2.1,<3.0" [package.extras] dev = ["bandit", "black", "codecov", "flake8", "huggingface-hub", "isort (>=5.0.0,<6)", "jupyter", "mypy", "nbconvert", "nbformat", "pydocstyle", "pylint", "pytest", "pytest-cov"] +[[package]] +name = "pydantic" +version = "2.5.1" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-2.5.1-py3-none-any.whl", hash = "sha256:dc5244a8939e0d9a68f1f1b5f550b2e1c879912033b1becbedb315accc75441b"}, + {file = "pydantic-2.5.1.tar.gz", hash = "sha256:0b8be5413c06aadfbe56f6dc1d45c9ed25fd43264414c571135c97dd77c2bedb"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.14.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.14.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:ba44fad1d114539d6a1509966b20b74d2dec9a5b0ee12dd7fd0a1bb7b8785e5f"}, + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4a70d23eedd88a6484aa79a732a90e36701048a1509078d1b59578ef0ea2cdf5"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cc24728a1a9cef497697e53b3d085fb4d3bc0ef1ef4d9b424d9cf808f52c146"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab4a2381005769a4af2ffddae74d769e8a4aae42e970596208ec6d615c6fb080"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a12bf088d6fa20e094f9a477bf84bd823651d8b8384f59bcd50eaa92e6a52"}, + {file = 
"pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38aed5a1bbc3025859f56d6a32f6e53ca173283cb95348e03480f333b1091e7d"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1767bd3f6370458e60c1d3d7b1d9c2751cc1ad743434e8ec84625a610c8b9195"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7cb0c397f29688a5bd2c0dbd44451bc44ebb9b22babc90f97db5ec3e5bb69977"}, + {file = "pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9ff737f24b34ed26de62d481ef522f233d3c5927279f6b7229de9b0deb3f76b5"}, + {file = "pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1a39fecb5f0b19faee9a8a8176c805ed78ce45d760259a4ff3d21a7daa4dfc1"}, + {file = "pydantic_core-2.14.3-cp310-none-win32.whl", hash = "sha256:ccbf355b7276593c68fa824030e68cb29f630c50e20cb11ebb0ee450ae6b3d08"}, + {file = "pydantic_core-2.14.3-cp310-none-win_amd64.whl", hash = "sha256:536e1f58419e1ec35f6d1310c88496f0d60e4f182cacb773d38076f66a60b149"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:f1f46700402312bdc31912f6fc17f5ecaaaa3bafe5487c48f07c800052736289"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:88ec906eb2d92420f5b074f59cf9e50b3bb44f3cb70e6512099fdd4d88c2f87c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:056ea7cc3c92a7d2a14b5bc9c9fa14efa794d9f05b9794206d089d06d3433dc7"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:076edc972b68a66870cec41a4efdd72a6b655c4098a232314b02d2bfa3bfa157"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e71f666c3bf019f2490a47dddb44c3ccea2e69ac882f7495c68dc14d4065eac2"}, + {file = 
"pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f518eac285c9632be337323eef9824a856f2680f943a9b68ac41d5f5bad7df7c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dbab442a8d9ca918b4ed99db8d89d11b1f067a7dadb642476ad0889560dac79"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0653fb9fc2fa6787f2fa08631314ab7fc8070307bd344bf9471d1b7207c24623"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c54af5069da58ea643ad34ff32fd6bc4eebb8ae0fef9821cd8919063e0aeeaab"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc956f78651778ec1ab105196e90e0e5f5275884793ab67c60938c75bcca3989"}, + {file = "pydantic_core-2.14.3-cp311-none-win32.whl", hash = "sha256:5b73441a1159f1fb37353aaefb9e801ab35a07dd93cb8177504b25a317f4215a"}, + {file = "pydantic_core-2.14.3-cp311-none-win_amd64.whl", hash = "sha256:7349f99f1ef8b940b309179733f2cad2e6037a29560f1b03fdc6aa6be0a8d03c"}, + {file = "pydantic_core-2.14.3-cp311-none-win_arm64.whl", hash = "sha256:ec79dbe23702795944d2ae4c6925e35a075b88acd0d20acde7c77a817ebbce94"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8f5624f0f67f2b9ecaa812e1dfd2e35b256487566585160c6c19268bf2ffeccc"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6c2d118d1b6c9e2d577e215567eedbe11804c3aafa76d39ec1f8bc74e918fd07"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe863491664c6720d65ae438d4efaa5eca766565a53adb53bf14bc3246c72fe0"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:136bc7247e97a921a020abbd6ef3169af97569869cd6eff41b6a15a73c44ea9b"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:aeafc7f5bbddc46213707266cadc94439bfa87ecf699444de8be044d6d6eb26f"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e16aaf788f1de5a85c8f8fcc9c1ca1dd7dd52b8ad30a7889ca31c7c7606615b8"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc652c354d3362e2932a79d5ac4bbd7170757a41a62c4fe0f057d29f10bebb"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f1b92e72babfd56585c75caf44f0b15258c58e6be23bc33f90885cebffde3400"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:75f3f534f33651b73f4d3a16d0254de096f43737d51e981478d580f4b006b427"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c9ffd823c46e05ef3eb28b821aa7bc501efa95ba8880b4a1380068e32c5bed47"}, + {file = "pydantic_core-2.14.3-cp312-none-win32.whl", hash = "sha256:12e05a76b223577a4696c76d7a6b36a0ccc491ffb3c6a8cf92d8001d93ddfd63"}, + {file = "pydantic_core-2.14.3-cp312-none-win_amd64.whl", hash = "sha256:1582f01eaf0537a696c846bea92082082b6bfc1103a88e777e983ea9fbdc2a0f"}, + {file = "pydantic_core-2.14.3-cp312-none-win_arm64.whl", hash = "sha256:96fb679c7ca12a512d36d01c174a4fbfd912b5535cc722eb2c010c7b44eceb8e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:71ed769b58d44e0bc2701aa59eb199b6665c16e8a5b8b4a84db01f71580ec448"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:5402ee0f61e7798ea93a01b0489520f2abfd9b57b76b82c93714c4318c66ca06"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaab9dc009e22726c62fe3b850b797e7f0e7ba76d245284d1064081f512c7226"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92486a04d54987054f8b4405a9af9d482e5100d6fe6374fc3303015983fc8bda"}, + {file = 
"pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf08b43d1d5d1678f295f0431a4a7e1707d4652576e1d0f8914b5e0213bfeee5"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8ca13480ce16daad0504be6ce893b0ee8ec34cd43b993b754198a89e2787f7e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44afa3c18d45053fe8d8228950ee4c8eaf3b5a7f3b64963fdeac19b8342c987f"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56814b41486e2d712a8bc02a7b1f17b87fa30999d2323bbd13cf0e52296813a1"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3dc2920cc96f9aa40c6dc54256e436cc95c0a15562eb7bd579e1811593c377e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e483b8b913fcd3b48badec54185c150cb7ab0e6487914b84dc7cde2365e0c892"}, + {file = "pydantic_core-2.14.3-cp37-none-win32.whl", hash = "sha256:364dba61494e48f01ef50ae430e392f67ee1ee27e048daeda0e9d21c3ab2d609"}, + {file = "pydantic_core-2.14.3-cp37-none-win_amd64.whl", hash = "sha256:a402ae1066be594701ac45661278dc4a466fb684258d1a2c434de54971b006ca"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:10904368261e4509c091cbcc067e5a88b070ed9a10f7ad78f3029c175487490f"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:260692420028319e201b8649b13ac0988974eeafaaef95d0dfbf7120c38dc000"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c1bf1a7b05a65d3b37a9adea98e195e0081be6b17ca03a86f92aeb8b110f468"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d7abd17a838a52140e3aeca271054e321226f52df7e0a9f0da8f91ea123afe98"}, + {file = 
"pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5c51460ede609fbb4fa883a8fe16e749964ddb459966d0518991ec02eb8dfb9"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d06c78074646111fb01836585f1198367b17d57c9f427e07aaa9ff499003e58d"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af452e69446fadf247f18ac5d153b1f7e61ef708f23ce85d8c52833748c58075"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3ad4968711fb379a67c8c755beb4dae8b721a83737737b7bcee27c05400b047"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c5ea0153482e5b4d601c25465771c7267c99fddf5d3f3bdc238ef930e6d051cf"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:96eb10ef8920990e703da348bb25fedb8b8653b5966e4e078e5be382b430f9e0"}, + {file = "pydantic_core-2.14.3-cp38-none-win32.whl", hash = "sha256:ea1498ce4491236d1cffa0eee9ad0968b6ecb0c1cd711699c5677fc689905f00"}, + {file = "pydantic_core-2.14.3-cp38-none-win_amd64.whl", hash = "sha256:2bc736725f9bd18a60eec0ed6ef9b06b9785454c8d0105f2be16e4d6274e63d0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:1ea992659c03c3ea811d55fc0a997bec9dde863a617cc7b25cfde69ef32e55af"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d2b53e1f851a2b406bbb5ac58e16c4a5496038eddd856cc900278fa0da97f3fc"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c7f8e8a7cf8e81ca7d44bea4f181783630959d41b4b51d2f74bc50f348a090f"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8d3b9c91eeb372a64ec6686c1402afd40cc20f61a0866850f7d989b6bf39a41a"}, + {file = 
"pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ef3e2e407e4cad2df3c89488a761ed1f1c33f3b826a2ea9a411b0a7d1cccf1b"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f86f20a9d5bee1a6ede0f2757b917bac6908cde0f5ad9fcb3606db1e2968bcf5"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61beaa79d392d44dc19d6f11ccd824d3cccb865c4372157c40b92533f8d76dd0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d41df8e10b094640a6b234851b624b76a41552f637b9fb34dc720b9fe4ef3be4"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c08ac60c3caa31f825b5dbac47e4875bd4954d8f559650ad9e0b225eaf8ed0c"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d8b3932f1a369364606417ded5412c4ffb15bedbcf797c31317e55bd5d920e"}, + {file = "pydantic_core-2.14.3-cp39-none-win32.whl", hash = "sha256:caa94726791e316f0f63049ee00dff3b34a629b0d099f3b594770f7d0d8f1f56"}, + {file = "pydantic_core-2.14.3-cp39-none-win_amd64.whl", hash = "sha256:2494d20e4c22beac30150b4be3b8339bf2a02ab5580fa6553ca274bc08681a65"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:fe272a72c7ed29f84c42fedd2d06c2f9858dc0c00dae3b34ba15d6d8ae0fbaaf"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7e63a56eb7fdee1587d62f753ccd6d5fa24fbeea57a40d9d8beaef679a24bdd6"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7692f539a26265cece1e27e366df5b976a6db6b1f825a9e0466395b314ee48b"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af46f0b7a1342b49f208fed31f5a83b8495bb14b652f621e0a6787d2f10f24ee"}, + {file = 
"pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6e2f9d76c00e805d47f19c7a96a14e4135238a7551a18bfd89bb757993fd0933"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:de52ddfa6e10e892d00f747bf7135d7007302ad82e243cf16d89dd77b03b649d"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:38113856c7fad8c19be7ddd57df0c3e77b1b2336459cb03ee3903ce9d5e236ce"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:354db020b1f8f11207b35360b92d95725621eb92656725c849a61e4b550f4acc"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:76fc18653a5c95e5301a52d1b5afb27c9adc77175bf00f73e94f501caf0e05ad"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2646f8270f932d79ba61102a15ea19a50ae0d43b314e22b3f8f4b5fabbfa6e38"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37dad73a2f82975ed563d6a277fd9b50e5d9c79910c4aec787e2d63547202315"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:113752a55a8eaece2e4ac96bc8817f134c2c23477e477d085ba89e3aa0f4dc44"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:8488e973547e8fb1b4193fd9faf5236cf1b7cd5e9e6dc7ff6b4d9afdc4c720cb"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:3d1dde10bd9962b1434053239b1d5490fc31a2b02d8950a5f731bc584c7a5a0f"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2c83892c7bf92b91d30faca53bb8ea21f9d7e39f0ae4008ef2c2f91116d0464a"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:849cff945284c577c5f621d2df76ca7b60f803cc8663ff01b778ad0af0e39bb9"}, + {file = 
"pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa89919fbd8a553cd7d03bf23d5bc5deee622e1b5db572121287f0e64979476"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf15145b1f8056d12c67255cd3ce5d317cd4450d5ee747760d8d088d85d12a2d"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4cc6bb11f4e8e5ed91d78b9880774fbc0856cb226151b0a93b549c2b26a00c19"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:832d16f248ca0cc96929139734ec32d21c67669dcf8a9f3f733c85054429c012"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b02b5e1f54c3396c48b665050464803c23c685716eb5d82a1d81bf81b5230da4"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:1f2d4516c32255782153e858f9a900ca6deadfb217fd3fb21bb2b60b4e04d04d"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:0a3e51c2be472b7867eb0c5d025b91400c2b73a0823b89d4303a9097e2ec6655"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:df33902464410a1f1a0411a235f0a34e7e129f12cb6340daca0f9d1390f5fe10"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27828f0227b54804aac6fb077b6bb48e640b5435fdd7fbf0c274093a7b78b69c"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e2979dc80246e18e348de51246d4c9b410186ffa3c50e77924bec436b1e36cb"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b28996872b48baf829ee75fa06998b607c66a4847ac838e6fd7473a6b2ab68e7"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = 
"sha256:ca55c9671bb637ce13d18ef352fd32ae7aba21b4402f300a63f1fb1fd18e0364"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:aecd5ed096b0e5d93fb0367fd8f417cef38ea30b786f2501f6c34eabd9062c38"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:44aaf1a07ad0824e407dafc637a852e9a44d94664293bbe7d8ee549c356c8882"}, + {file = "pydantic_core-2.14.3.tar.gz", hash = "sha256:3ad083df8fe342d4d8d00cc1d3c1a23f0dc84fce416eb301e69f1ddbbe124d3f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pydub" version = "0.25.1" @@ -2549,6 +2780,17 @@ files = [ {file = "pygtrie-2.5.0.tar.gz", hash = "sha256:203514ad826eb403dab1d2e2ddd034e0d1534bbe4dbe0213bb0593f66beba4e2"}, ] +[[package]] +name = "pynvml" +version = "11.5.0" +description = "Python Bindings for the NVIDIA Management Library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pynvml-11.5.0-py3-none-any.whl", hash = "sha256:5cce014ac01b098d08f06178f86c37be409b80b2e903a5a03ce15eed60f55e25"}, + {file = "pynvml-11.5.0.tar.gz", hash = "sha256:d027b21b95b1088b9fc278117f9f61b7c67f8e33a787e9f83f735f0f71ac32d0"}, +] + [[package]] name = "pyparsing" version = "3.1.1" @@ -4141,4 +4383,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "b3d1d84f9fd81731f401e79297454321a9984b16a8210ca92ad92e674e0da732" +content-hash = "6b39ddcc354a773f4a1987e0aea6d60d8fe93320a240b107915d8818d23f07d6" diff --git a/pyproject.toml b/pyproject.toml index 59437db1..e2d3e796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ pycountry = "^22.3.5" wave = ">=0.0.2,<1.0.0" kenlm = {url = "https://github.com/kpu/kenlm/archive/master.zip"} matplotlib = "3.7.3" +deepspeed = ">=0.12.3,<1.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.0.0" diff --git a/src/coral_models/compute_metrics.py b/src/coral_models/compute_metrics.py index be4b2b78..745dfafc 
100644 --- a/src/coral_models/compute_metrics.py +++ b/src/coral_models/compute_metrics.py @@ -4,9 +4,13 @@ from evaluate.loading import load as load_metric from numpy.typing import NDArray from transformers import EvalPrediction, PreTrainedTokenizerBase +import logging +import os from .protocols import Processor +logger = logging.getLogger(__name__) + def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str, float]: """Compute the word error rate of predictions. @@ -63,7 +67,16 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str, labels[labels == -100] = pad_token # Decode the ground truth labels - labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True) + labels_str = tokenizer.batch_decode( + sequences=labels, skip_special_tokens=True, group_tokens=False + ) + + # TEMP: Log both the predictions and the ground truth labels + is_main_process = os.getenv("RANK", "0") == "0" + if is_main_process: + random_idx = np.random.randint(0, len(predictions_str)) + logger.info(f"Sample document: {labels_str[random_idx]}") + logger.info(f"Predicted: {predictions_str[random_idx]}") # Compute the word error rate computed = wer_metric.compute(predictions=predictions_str, references=labels_str) diff --git a/src/coral_models/data.py b/src/coral_models/data.py index 5ed94bf2..a3011ce7 100644 --- a/src/coral_models/data.py +++ b/src/coral_models/data.py @@ -34,9 +34,13 @@ def load_data(cfg: DictConfig) -> DatasetDict | IterableDatasetDict: ValueError: If the dataset is not supported. 
""" + # Note if we're on the main process, if we are running in a distributed setting + is_main_process = os.getenv("RANK", "0") == "0" + all_datasets: list[DatasetDict | IterableDatasetDict] = list() for dataset_name, dataset_cfg in cfg.datasets.items(): - logger.info(f"Loading dataset {dataset_name!r}") + if is_main_process: + logger.info(f"Loading dataset {dataset_name!r}") # Load from disk if the dataset ID is a path if Path(dataset_cfg.id).exists(): @@ -126,14 +130,16 @@ def load_data(cfg: DictConfig) -> DatasetDict | IterableDatasetDict: assert len(all_datasets) > 0, "No datasets were loaded" if len(all_datasets) > 1: - logger.info("Interleaving datasets") - if cfg.dataset_probabilities["train"] is None and len(all_datasets) > 1: - logger.warning( - "No dataset probabilities were specified for the training split. " - "This means that each dataset will be sampled with equal probability, " - "which means that the smaller datasets will be sampled more often than " - "the larger datasets. This is probably not what you want." - ) + if is_main_process: + logger.info("Interleaving datasets") + if cfg.dataset_probabilities["train"] is None and len(all_datasets) > 1: + logger.warning( + "No dataset probabilities were specified for the training split. " + "This means that each dataset will be sampled with equal " + "probability, which means that the smaller datasets will be " + "sampled more often than the larger datasets. This is probably " + "not what you want." 
+ ) probabilities: dict[str, list[float]] = dict() for split_name, split_probs in cfg.dataset_probabilities.items(): diff --git a/src/coral_models/finetune.py b/src/coral_models/finetune.py index 089d64b8..723271bf 100644 --- a/src/coral_models/finetune.py +++ b/src/coral_models/finetune.py @@ -3,16 +3,17 @@ from functools import partial import logging from typing import Callable +import os from omegaconf import DictConfig from transformers import EarlyStoppingCallback, TrainerCallback from wandb.sdk.wandb_init import init as wandb_init from wandb.sdk.wandb_run import finish as wandb_finish +from .utils import disable_tqdm from .data import load_data from .model_setup import load_model_setup from .protocols import ModelSetup -from .utils import disable_tqdm logger = logging.getLogger(__package__) @@ -64,6 +65,9 @@ def finetune(cfg: DictConfig) -> None: Args: cfg: The Hydra cfguration object. """ + # Note if we're on the main process, if we are running in a distributed setting + is_main_process = os.getenv("RANK", "0") == "0" + model_setup: ModelSetup = load_model_setup(cfg) processor = model_setup.load_processor() processor.save_pretrained(cfg.model_dir) @@ -81,7 +85,7 @@ def finetune(cfg: DictConfig) -> None: ), ) - if cfg.wandb: + if cfg.wandb and is_main_process: wandb_init( project=cfg.wandb_project, group=cfg.wandb_group, @@ -89,7 +93,7 @@ def finetune(cfg: DictConfig) -> None: config=dict(cfg), ) - if "val" not in dataset: + if "val" not in dataset and is_main_process: logger.info("No validation set found. 
Disabling early stopping.") trainer = model_setup.load_trainer_class()( @@ -105,11 +109,12 @@ def finetune(cfg: DictConfig) -> None: with disable_tqdm(): trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) - wandb_finish() - model.save_pretrained(cfg.model_dir) - if cfg.push_to_hub: - trainer.push_to_hub() + if is_main_process: + wandb_finish() + model.save_pretrained(cfg.model_dir) + if cfg.push_to_hub: + trainer.push_to_hub() def load_early_stopping_callback(cfg: DictConfig) -> list[TrainerCallback]: diff --git a/src/coral_models/wav2vec2.py b/src/coral_models/wav2vec2.py index 68ec2b10..1092f3c6 100644 --- a/src/coral_models/wav2vec2.py +++ b/src/coral_models/wav2vec2.py @@ -7,6 +7,8 @@ from functools import partial from pathlib import Path from typing import Callable, Type +import time +import os import torch from omegaconf import DictConfig @@ -58,7 +60,7 @@ class DataCollatorCTCWithPadding(DataCollatorMixin): """ processor: Wav2Vec2Processor - padding: bool | str = True + padding: bool | str return_tensors: str = "pt" def torch_call(self, features: list[dict]) -> BatchFeature: @@ -81,12 +83,18 @@ def torch_call(self, features: list[dict]) -> BatchFeature: "Features must contain either 'input_values' or 'audio' key." 
) batch: BatchFeature = self.processor.pad( - audio_features, padding=self.padding, return_tensors="pt" + audio_features, + padding=self.padding, + return_tensors=self.return_tensors, + max_length=16_000 * 10, ) label_features = [dict(input_ids=feature["labels"]) for feature in features] labels_batch: BatchEncoding = self.processor.tokenizer.pad( - label_features, padding=self.padding, return_tensors="pt" + label_features, + padding=self.padding, + return_tensors=self.return_tensors, + max_length=512, ) # Replace padding with -100 to ignore loss correctly @@ -112,15 +120,25 @@ def __init__(self, cfg: DictConfig) -> None: def load_processor(self) -> Wav2Vec2Processor: # We dump the vocabulary to a file since the tokenizer uses this file during # initialisation - dump_vocabulary(self.cfg) - tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained( - self.cfg.model_dir, - unk_token="<unk>", - pad_token="<pad>", - bos_token="<s>", - eos_token="</s>", - word_delimiter_token=" ", - ) + while True: + try: + dump_vocabulary(self.cfg) + tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained( + self.cfg.model_dir, + unk_token="<unk>", + pad_token="<pad>", + bos_token="<s>", + eos_token="</s>", + word_delimiter_token=" ", + ) + break + except json.decoder.JSONDecodeError: + process_id = os.getenv("RANK", 0) + logger.warning( + f"JSONDecodeError while loading tokenizer on process {process_id}. " + "Retrying in a second."
+ ) + time.sleep(1) # Set the `model_max_length` attribute of the tokenizer, if it hasn't been set, # to ensure that truncation is done correctly @@ -170,7 +188,9 @@ def load_model(self) -> Wav2Vec2ForCTC: return model def load_data_collator(self) -> DataCollatorCTCWithPadding: - return DataCollatorCTCWithPadding(processor=self.processor, padding=True) + return DataCollatorCTCWithPadding( + processor=self.processor, padding=self.cfg.padding + ) def load_trainer_class(self) -> Type[Trainer]: return Trainer @@ -179,6 +199,23 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]: return partial(compute_wer_metrics, processor=self.processor) def load_training_arguments(self) -> TrainingArguments: + # Compute the gradient accumulation based on the total batch size in the config + num_devices = max(torch.cuda.device_count(), 1) + per_device_total_batch_size = self.cfg.total_batch_size // num_devices + gradient_accumulation_steps = ( + per_device_total_batch_size // self.cfg.per_device_batch_size + ) + + if gradient_accumulation_steps == 0: + logger.warning( + f"Your `total_batch_size` is too small ({self.cfg.total_batch_size}), " + f"relative to the number of devices ({num_devices}) and your " + f"`per_device_batch_size` ({self.cfg.per_device_batch_size}). It has " + f"been set to `per_device_batch_size * num_devices` = " + f"{self.cfg.per_device_batch_size * num_devices}." 
+ ) + gradient_accumulation_steps = 1 + do_eval = any( [ dataset_cfg.val_name is not None @@ -188,9 +225,9 @@ def load_training_arguments(self) -> TrainingArguments: args = TrainingArguments( output_dir=self.cfg.model_dir, hub_model_id=self.cfg.hub_id, - per_device_train_batch_size=self.cfg.batch_size, - per_device_eval_batch_size=self.cfg.batch_size, - gradient_accumulation_steps=self.cfg.gradient_accumulation, + per_device_train_batch_size=self.cfg.per_device_batch_size, + per_device_eval_batch_size=self.cfg.per_device_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, lr_scheduler_type=SchedulerType.COSINE, warmup_steps=self.cfg.warmup_steps, @@ -200,6 +237,7 @@ def load_training_arguments(self) -> TrainingArguments: evaluation_strategy="steps" if do_eval else "no", eval_steps=self.cfg.eval_steps if do_eval else None, save_steps=self.cfg.save_steps, + save_strategy="no" if self.cfg.save_total_limit == 0 else "steps", logging_steps=self.cfg.logging_steps, length_column_name="input_length", gradient_checkpointing=True, @@ -217,6 +255,7 @@ def load_training_arguments(self) -> TrainingArguments: save_safetensors=True, use_cpu=hasattr(sys, "_called_from_test"), dataloader_num_workers=self.cfg.dataloader_num_workers, + ddp_find_unused_parameters=False, ) return args @@ -236,7 +275,7 @@ def load_saved(self) -> PreTrainedModelData: model = Wav2Vec2ForCTC.from_pretrained(self.cfg.hub_id, token=True) data_collator = DataCollatorCTCWithPadding( - processor=processor, padding="longest" + processor=processor, padding=self.cfg.padding ) compute_metrics = partial(compute_wer_metrics, processor=processor) return PreTrainedModelData( diff --git a/src/coral_models/whisper.py b/src/coral_models/whisper.py index ecd6de8e..3ade4940 100644 --- a/src/coral_models/whisper.py +++ b/src/coral_models/whisper.py @@ -7,6 +7,7 @@ from typing import Callable, Type from omegaconf import DictConfig +import torch from torch.backends.mps 
import is_available as mps_is_available from transformers import ( BatchFeature, @@ -161,7 +162,7 @@ def load_model(self) -> WhisperForConditionalGeneration: def load_data_collator(self) -> DataCollatorSpeechSeq2SeqWithPadding: return DataCollatorSpeechSeq2SeqWithPadding( - processor=self.processor, padding=True + processor=self.processor, padding=self.cfg.padding ) def load_trainer_class(self) -> Type[Trainer]: @@ -171,6 +172,23 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]: return partial(compute_wer_metrics, processor=self.processor) def load_training_arguments(self) -> TrainingArguments: + # Compute the gradient accumulation based on the total batch size in the config + num_devices = max(torch.cuda.device_count(), 1) + per_device_total_batch_size = self.cfg.total_batch_size // num_devices + gradient_accumulation_steps = ( + per_device_total_batch_size // self.cfg.per_device_batch_size + ) + + if gradient_accumulation_steps == 0: + logger.warning( + f"Your `total_batch_size` is too small ({self.cfg.total_batch_size}), " + f"relative to the number of devices ({num_devices}) and your " + f"`per_device_batch_size` ({self.cfg.per_device_batch_size}). It has " + f"been set to `per_device_batch_size * num_devices` = " + f"{self.cfg.per_device_batch_size * num_devices}." 
+ ) + gradient_accumulation_steps = 1 + do_eval = any( [ dataset_cfg.val_name is not None @@ -180,9 +198,9 @@ def load_training_arguments(self) -> TrainingArguments: args = Seq2SeqTrainingArguments( output_dir=self.cfg.model_dir, hub_model_id=self.cfg.hub_id, - per_device_train_batch_size=self.cfg.batch_size, - per_device_eval_batch_size=self.cfg.batch_size, - gradient_accumulation_steps=self.cfg.gradient_accumulation, + per_device_train_batch_size=self.cfg.per_device_batch_size, + per_device_eval_batch_size=self.cfg.per_device_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, warmup_steps=self.cfg.warmup_steps, max_steps=self.cfg.max_steps, @@ -208,6 +226,7 @@ def load_training_arguments(self) -> TrainingArguments: generation_max_length=self.cfg.model.generation_max_length, use_cpu=hasattr(sys, "_called_from_test"), dataloader_num_workers=self.cfg.dataloader_num_workers, + ddp_find_unused_parameters=False, ) return args @@ -219,7 +238,7 @@ def load_saved(self) -> PreTrainedModelData: self.cfg.hub_id, token=True ) data_collator = DataCollatorSpeechSeq2SeqWithPadding( - processor=processor, padding="longest" + processor=processor, padding=self.cfg.padding ) compute_metrics = partial(compute_wer_metrics, processor=processor) return PreTrainedModelData( diff --git a/src/scripts/finetune_model.py b/src/scripts/finetune_model.py index b1e9b197..dd77d0ca 100644 --- a/src/scripts/finetune_model.py +++ b/src/scripts/finetune_model.py @@ -6,10 +6,15 @@ import hydra from omegaconf import DictConfig +import os +import logging from coral_models.finetune import finetune +logger = logging.getLogger(__name__) + + @hydra.main(config_path="../../config", config_name="config", version_base=None) def main(cfg: DictConfig) -> None: """Finetune an ASR model. @@ -18,6 +23,25 @@ def main(cfg: DictConfig) -> None: cfg (DictConfig): The Hydra configuration object. 
""" + # In case we are running in a multi-GPU setting, we need to force certain + # hyperparameters + is_main_process = os.getenv("RANK", "0") == "0" + if os.getenv("WORLD_SIZE") is not None: + if "layerdrop" in cfg.model and cfg.model.layerdrop != 0.0: + if is_main_process: + logger.info( + "Forcing `layerdrop` to be 0.0 as this is required in a multi-GPU " + "training" + ) + cfg.model.layerdrop = 0.0 + if cfg.padding != "max_length": + if is_main_process: + logger.info( + "Forcing `padding` to be 'max_length' as this is required in a " + "multi-GPU training" + ) + cfg.padding = "max_length" + finetune(cfg) diff --git a/tests/conftest.py b/tests/conftest.py index 92f6bba2..60d932c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,8 +43,8 @@ def cfg(request) -> Generator[DictConfig, None, None]: f"model={model}", f"datasets={datasets}", "fp16=false", - "batch_size=2", - "gradient_accumulation=1", + "total_batch_size=2", + "per_device_batch_size=2", "max_steps=2", "save_total_limit=0", ],