From d1a335bc7543a44b1260b01109553e20f9891c0f Mon Sep 17 00:00:00 2001
From: Adam Louly
 <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Date: Mon, 11 Dec 2023 12:56:23 -0800
Subject: [PATCH 1/4] refactor fsdp

---
 examples/onnxruntime/training/language-modeling/run_clm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/onnxruntime/training/language-modeling/run_clm.py b/examples/onnxruntime/training/language-modeling/run_clm.py
index d4a473993a..8ac31e77f8 100644
--- a/examples/onnxruntime/training/language-modeling/run_clm.py
+++ b/examples/onnxruntime/training/language-modeling/run_clm.py
@@ -408,7 +408,7 @@ def main():
             logger.info(f"Overriding config: {model_args.config_overrides}")
             config.update_from_string(model_args.config_overrides)
             logger.info(f"New config: {config}")
-
+    config.num_hidden_layers = 2
     tokenizer_kwargs = {
         "cache_dir": model_args.cache_dir,
         "use_fast": model_args.use_fast_tokenizer,

From 1558e1768d87455e13d035f3a4a8b905130b1bd7 Mon Sep 17 00:00:00 2001
From: Adam Louly
 <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Date: Mon, 11 Dec 2023 13:10:49 -0800
Subject: [PATCH 2/4] add trainer

---
 optimum/onnxruntime/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index afc90e405b..7c6b0a5227 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -455,7 +455,7 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
 
         # Wrap the model with `ORTModule`
         logger.info("Wrap ORTModule for ONNX Runtime training.")
@@ -883,7 +883,7 @@ def _wrap_model(self, model, training=True, dataloader=None):
             return model
 
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        if self.is_fsdp_xla_enabled:
             try:
                 from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                 from torch_xla.distributed.fsdp import checkpoint_module

From 6408ec6822c92539ebd33553ccee478b9a778163 Mon Sep 17 00:00:00 2001
From: Adam Louly
 <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Date: Mon, 11 Dec 2023 13:14:00 -0800
Subject: [PATCH 3/4] remove hidden layers

---
 examples/onnxruntime/training/language-modeling/run_clm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/onnxruntime/training/language-modeling/run_clm.py b/examples/onnxruntime/training/language-modeling/run_clm.py
index 8ac31e77f8..d4a473993a 100644
--- a/examples/onnxruntime/training/language-modeling/run_clm.py
+++ b/examples/onnxruntime/training/language-modeling/run_clm.py
@@ -408,7 +408,7 @@ def main():
             logger.info(f"Overriding config: {model_args.config_overrides}")
             config.update_from_string(model_args.config_overrides)
             logger.info(f"New config: {config}")
-    config.num_hidden_layers = 2
+
     tokenizer_kwargs = {
         "cache_dir": model_args.cache_dir,
         "use_fast": model_args.use_fast_tokenizer,

From 87b6adb1f01201817c865f91ad2f5c96dd6c5f71 Mon Sep 17 00:00:00 2001
From: JingyaHuang <huang_jingya@outlook.com>
Date: Tue, 26 Dec 2023 13:38:29 +0000
Subject: [PATCH 4/4] update dockerfile

---
 ...Dockerfile-ort1.16.1-cu118 => Dockerfile-ort1.16.3-cu118} | 5 ++++-
 tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer      | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)
 rename examples/onnxruntime/training/docker/{Dockerfile-ort1.16.1-cu118 => Dockerfile-ort1.16.3-cu118} (94%)

diff --git a/examples/onnxruntime/training/docker/Dockerfile-ort1.16.1-cu118 b/examples/onnxruntime/training/docker/Dockerfile-ort1.16.3-cu118
similarity index 94%
rename from examples/onnxruntime/training/docker/Dockerfile-ort1.16.1-cu118
rename to examples/onnxruntime/training/docker/Dockerfile-ort1.16.3-cu118
index 482d495fcb..8b6a8c38bd 100644
--- a/examples/onnxruntime/training/docker/Dockerfile-ort1.16.1-cu118
+++ b/examples/onnxruntime/training/docker/Dockerfile-ort1.16.3-cu118
@@ -65,12 +65,15 @@ RUN $PYTHON_EXE -m pip install onnx ninja
 RUN $PYTHON_EXE -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} -f https://download.pytorch.org/whl/${TORCH_CUDA_VERSION}
 
 # ORT Module
-RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.1 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
+RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.3 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
 RUN $PYTHON_EXE -m pip install torch-ort
 ENV TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX"
 RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
 RUN $PYTHON_EXE -m torch_ort.configure
 
+# https://github.com/vllm-project/vllm/issues/1726
+RUN pip uninstall nvidia-nccl-cu12 -y
+
 WORKDIR .
 
 CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer b/tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer
index 7266ba224a..74add3f07e 100644
--- a/tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer
+++ b/tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer
@@ -65,12 +65,15 @@ RUN $PYTHON_EXE -m pip install onnx ninja
 RUN $PYTHON_EXE -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} -f https://download.pytorch.org/whl/${TORCH_CUDA_VERSION}
 
 # ORT Module
-RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.1 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
+RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.3 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
 RUN $PYTHON_EXE -m pip install torch-ort
 ENV TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX"
 RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
 RUN $PYTHON_EXE -m torch_ort.configure
 
+# https://github.com/vllm-project/vllm/issues/1726
+RUN pip uninstall nvidia-nccl-cu12 -y
+
 # Install Optimum
 COPY . /workspace/optimum
 RUN pip install /workspace/optimum[tests]