diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 97a382ee8b..b88cddcc18 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -55,6 +55,9 @@ FF_NEW_OPAQUE_TYPE(flexflow_inference_manager_t); FF_NEW_OPAQUE_TYPE(flexflow_request_manager_t); FF_NEW_OPAQUE_TYPE(flexflow_file_data_loader_t); FF_NEW_OPAQUE_TYPE(flexflow_generation_result_t); +// FF_NEW_OPAQUE_TYPE(flexflow_lora_optimizer_config_t); +// FF_NEW_OPAQUE_TYPE(flexflow_lora_sgd_optimizer_config_t); +// FF_NEW_OPAQUE_TYPE(flexflow_lora_adam_optimizer_config_t); FF_NEW_OPAQUE_TYPE(flexflow_lora_linear_config_t); FF_NEW_OPAQUE_TYPE(flexflow_peft_model_id_t); @@ -1050,16 +1053,93 @@ void flexflow_file_data_loader_destroy(flexflow_file_data_loader_t handle_); void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_, flexflow_model_t model_handle_); +// // ----------------------------------------------------------------------- +// // LoraSGDOptimizerConfig +// // ----------------------------------------------------------------------- + +// flexflow_lora_sgd_optimizer_config_t +// flexflow_lora_sgd_optimizer_config_create( +// double lr, double momentum, bool nesterov, bool weight_decay); + +// void flexflow_lora_sgd_optimizer_config_destroy( +// flexflow_lora_sgd_optimizer_config_t handle_); + +// // ----------------------------------------------------------------------- +// // LoraAdamOptimizerConfig +// // ----------------------------------------------------------------------- + +// flexflow_lora_adam_optimizer_config_t +// flexflow_lora_adam_optimizer_config_create(double alpha, +// double beta1, +// double beta2, +// double weight_decay, +// double epsilon); + +// void flexflow_lora_adam_optimizer_config_destroy( +// flexflow_lora_adam_optimizer_config_t handle_); + // ----------------------------------------------------------------------- // LoraLinearConfig // ----------------------------------------------------------------------- flexflow_lora_linear_config_t flexflow_lora_linear_config_create(char const *cache_folder_, - char const *peft_model_id_); + char const *peft_model_id_, + bool trainable, + bool init_lora_weights, + int rank, + float lora_alpha, + float lora_dropout, + int num_target_modules, + char const **target_modules_, + enum OptimizerType optimizer_type, + float sgd_learning_rate, + float sgd_momentum, + bool sgd_nesterov, + float sgd_weight_decay, + float adam_alpha, + float adam_beta1, + float adam_beta2, + float adam_weight_decay, + float adam_epsilon); void flexflow_lora_linear_config_destroy(flexflow_lora_linear_config_t handle_); +char const *flexflow_lora_linear_config_get_cache_folder( + flexflow_lora_linear_config_t handle_); + +char const *flexflow_lora_linear_config_get_peft_model_id( + flexflow_lora_linear_config_t handle_); + +int flexflow_lora_linear_config_get_rank(flexflow_lora_linear_config_t handle_); + +float flexflow_lora_linear_config_get_lora_alpha( + flexflow_lora_linear_config_t handle_); + +float flexflow_lora_linear_config_get_lora_dropout( + flexflow_lora_linear_config_t handle_); + +bool flexflow_lora_linear_config_get_trainable( + flexflow_lora_linear_config_t handle_); + +bool flexflow_lora_linear_config_get_init_lora_weights( + flexflow_lora_linear_config_t handle_); + +char const **flexflow_lora_linear_config_get_target_modules( + flexflow_lora_linear_config_t handle_, int *num_target_modules); + +void flexflow_lora_linear_config_set_lora_alpha( + flexflow_lora_linear_config_t handle_, float value); + 
+void flexflow_lora_linear_config_set_lora_dropout( + flexflow_lora_linear_config_t handle_, float value); + +void flexflow_lora_linear_config_set_trainable( + flexflow_lora_linear_config_t handle_, bool value); + +void flexflow_lora_linear_config_set_init_lora_weights( + flexflow_lora_linear_config_t handle_, bool value); + // ----------------------------------------------------------------------- // PEFTModelID // ----------------------------------------------------------------------- diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h index d3c74201d4..2fa2ebeaa3 100644 --- a/include/flexflow/ops/kernels/lora_linear_kernels.h +++ b/include/flexflow/ops/kernels/lora_linear_kernels.h @@ -22,7 +22,7 @@ struct LoraLinearWeight { struct LoraLinearModelState { LoraLinearWeight weights; LoraOptimizerConfig const *optimizer_config; - double lora_alpha; + float lora_alpha; }; class LoraLinearMeta : public OpMeta { diff --git a/include/flexflow/ops/lora_linear_params.h b/include/flexflow/ops/lora_linear_params.h index a7afd48015..1887b008d6 100644 --- a/include/flexflow/ops/lora_linear_params.h +++ b/include/flexflow/ops/lora_linear_params.h @@ -7,6 +7,10 @@ #include "flexflow/op_meta.h" #include "flexflow/operator.h" #include "flexflow/parallel_tensor.h" +#include <filesystem> +#include <fstream> +#include <iostream> +#include <nlohmann/json.hpp> namespace FlexFlow { @@ -26,6 +30,9 @@ class LoraSGDOptimizerConfig : public LoraOptimizerConfig { friend std::ostream &operator<<(std::ostream &os, LoraSGDOptimizerConfig const &llc); + NLOHMANN_DEFINE_TYPE_INTRUSIVE( + LoraSGDOptimizerConfig, lr, momentum, nesterov, weight_decay) + public: double lr = 0.001f; double momentum = 0.0f; @@ -44,6 +51,9 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig { friend std::ostream &operator<<(std::ostream &os, LoraAdamOptimizerConfig const &llc); + NLOHMANN_DEFINE_TYPE_INTRUSIVE( + LoraAdamOptimizerConfig, alpha, beta1, beta2, weight_decay, epsilon) + public: // Adam double alpha = 0.001f; @@ -53,36 +63,59 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig { double epsilon = 1e-8; }; +// Serialization helpers +template <typename T> +void serialize_to_json_file(T const &obj, fs::path const &filepath); + +// Function to deserialize JSON from file and create object +template <typename T> +std::unique_ptr<T> deserialize_from_json_file(fs::path const &filepath); + class LoraLinearConfig { public: static const LoraLinearConfig EmptyConfig; - LoraLinearConfig(); - LoraLinearConfig(int _rank, - bool _trainable = false, - LoraOptimizerConfig *_optimizer_config = nullptr); LoraLinearConfig(std::string const &cache_folder_, std::string const &peft_model_id_, bool trainable_ = false, - LoraOptimizerConfig *optimizer_config_ = nullptr); + LoraOptimizerConfig *optimizer_config_ = nullptr, + bool init_lora_weights_ = false, + int rank_ = 8, + float lora_alpha_ = 8.0f, + float lora_dropout_ = 0.0f, + std::vector<std::string> const &target_modules_ = {}); + // constructor used to support std::unordered_map + LoraLinearConfig(); friend bool operator==(LoraLinearConfig const &lhs, LoraLinearConfig const &rhs); friend std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc); -public: + NLOHMANN_DEFINE_TYPE_INTRUSIVE(LoraLinearConfig, + cache_folder, + peft_model_id, + rank, + lora_alpha, + lora_dropout, + target_modules, + trainable, + init_lora_weights) + + std::string cache_folder; + // Huggingface model ID (for download and/or 
upload) + std::string peft_model_id; + // Lora parameters int rank; + float lora_alpha; + float lora_dropout; + std::vector<std::string> target_modules; + // Training parameters // whether the weights are trainable (fine-tuning scenario) or not // (inference-only). If set to true, allocate space for the gradients bool trainable = false; LoraOptimizerConfig *optimizer_config; - std::string cache_folder; - // Huggingface - std::string peft_model_id; - int lora_alpha; - float lora_dropout; - std::vector<std::string> target_modules; - // whether to load weights from file, instead of initializing them randomly - bool load_weights_from_file; + // whether to initialize weights randomly (instead of attempting to load them + // from file) + bool init_lora_weights; }; class LoraLinearParams { diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index d61db6f3bd..09131c625a 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -253,7 +253,8 @@ void FlexFlow::top_level_task(Task const *task, : LoraLinearConfig(file_paths.cache_folder_path, peft_model_name, true /*trainable*/, - optim_config); + optim_config, + false /*init_lora_weights*/); GenerationConfig generationConfig(do_sample, temperature, topp); RequestManager *rm = RequestManager::get_request_manager(); diff --git a/inference/python/ff_peft.py b/inference/python/ff_peft.py index 657748c6a9..b367272aca 100644 --- a/inference/python/ff_peft.py +++ b/inference/python/ff_peft.py @@ -59,21 +59,21 @@ def get_configs(): "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": True, - "fusion": True, + "fusion": False, } model_configs = { # required parameters "base_model": "JackFram/llama-160m", - "peft_model_ids": [ - "goliaro/llama-160m-lora", - ], + "inference_peft_model_id": "goliaro/llama-160m-lora", + "finetuning_peft_model_id": "goliaro/llama-160m-lora", # optional parameters "cache_path": "", "refresh_cache": False, - "full_precision": False, + "full_precision": True, "prompt": "", "finetuning_dataset": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../prompt/peft.json" + os.path.dirname(os.path.abspath(__file__)), + "../prompt/peft_dataset.json", ), "output_file": "", } @@ -100,8 +100,38 @@ def main(): refresh_cache=configs.refresh_cache, output_file=configs.output_file, ) - for peft_model_id in configs.peft_model_ids: - llm.add_peft(peft_model_id) + # Add inference and/or finetuning lora + lora_inference_config = None + lora_finetuning_config = None + if len(configs.prompt) > 0: + lora_inference_config = ff.LoraLinearConfig( + llm.cache_path, configs.inference_peft_model_id + ) + llm.add_peft(lora_inference_config) + if len(configs.finetuning_dataset) > 0: + # lora_finetuning_config = ff.LoraLinearConfig( + # llm.cache_path, + # configs.finetuning_peft_model_id, + # target_modules=["down_proj"], + # rank=16, + # lora_alpha=16, + # trainable=True, + # init_lora_weights=True, + # optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD, + # ) + lora_finetuning_config = ff.LoraLinearConfig( + llm.cache_path, + configs.inference_peft_model_id, + trainable=True, + optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD, + optimizer_kwargs={ + "learning_rate": 1.0, + "momentum": 0.0, + "weight_decay": 0.0, + "nesterov": False, + }, + ) + llm.add_peft(lora_finetuning_config) # Compile the LLM for inference and load the weights into memory generation_config = ff.GenerationConfig( @@ -109,10 +139,10 @@ def main(): ) llm.compile( generation_config, - enable_peft_finetuning = 
(len(configs.finetuning_dataset) > 0), + enable_peft_finetuning=(len(configs.finetuning_dataset) > 0), max_requests_per_batch=1, max_seq_length=256, - max_tokens_per_batch=64, + max_tokens_per_batch=128, ) llm.start_server() @@ -123,21 +153,24 @@ def main(): prompts = [s for s in json.load(open(configs.prompt))] inference_requests = [ ff.Request( - ff.RequestType.REQ_INFERENCE, prompt=prompt, max_sequence_length=128 + ff.RequestType.REQ_INFERENCE, + prompt=prompt, + max_sequence_length=128, + peft_model_id=llm.get_ff_peft_id(lora_inference_config), ) for prompt in prompts ] requests += inference_requests # Finetuning if len(configs.finetuning_dataset) > 0: - for peft_model_id in configs.peft_model_ids: - finetuning_request = ff.Request( - ff.RequestType.REQ_FINETUNING, - max_sequence_length=128, - peft_model_id=llm.get_ff_peft_id(peft_model_id), - dataset_filepath=configs.finetuning_dataset, - ) - requests.append(finetuning_request) + finetuning_request = ff.Request( + ff.RequestType.REQ_FINETUNING, + max_sequence_length=128, + peft_model_id=llm.get_ff_peft_id(lora_finetuning_config), + dataset_filepath=configs.finetuning_dataset, + max_training_steps=2, + ) + requests.append(finetuning_request) llm.generate(requests) diff --git a/inference/utils/download_peft_model.py b/inference/utils/download_peft_model.py index 596612d8d7..38dd577574 100644 --- a/inference/utils/download_peft_model.py +++ b/inference/utils/download_peft_model.py @@ -9,7 +9,10 @@ def parse_args(): "--base_model_name", type=str, help="Name of the model to download" ) parser.add_argument( - "peft_model_ids", type=str, nargs="+", help="Name of the PEFT model(s) to download" + "peft_model_ids", + type=str, + nargs="+", + help="Name of the PEFT model(s) to download", ) parser.add_argument( "--cache-folder", @@ -45,7 +48,6 @@ def main(args): else: data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF) - for data_type in data_types: llm = ff.LLM( args.base_model_name, @@ -54,7 +56,8 @@ def main(args): refresh_cache=args.refresh_cache, ) for peft_model_id in args.peft_model_ids: - llm.add_peft(peft_model_id) + lora_config = ff.LoraLinearConfig(llm.cache_path, peft_model_id) + llm.add_peft(lora_config) llm.download_hf_weights_if_needed() llm.download_hf_config() llm.download_hf_tokenizer_if_needed() diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index aa414f74d7..0f41b5235c 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -29,6 +29,7 @@ MetricsType, InferenceMode, RequestType, + OptimizerType, ModelType, OpType, ParameterSyncType, @@ -1599,23 +1600,28 @@ def register_ssm_model(self, model): def set_max_requests_per_batch(self, max_requests): return ffc().flexflow_request_manager_set_max_requests_per_batch( - self.handle, max_requests) - + self.handle, max_requests + ) + def set_max_tokens_per_batch(self, max_tokens): return ffc().flexflow_request_manager_set_max_tokens_per_batch( - self.handle, max_tokens) - + self.handle, max_tokens + ) + def set_max_spec_tree_token_num(self, max_tokens): return ffc().flexflow_request_manager_set_max_spec_tree_token_num( - self.handle, max_tokens) - + self.handle, max_tokens + ) + def set_max_sequence_length(self, max_length): return ffc().flexflow_request_manager_set_max_sequence_length( - self.handle, max_length) - + self.handle, max_length + ) + def set_enable_peft_finetuning(self, enable_peft_finetuning): return ffc().flexflow_request_manager_set_enable_peft_finetuning( - self.handle, 
enable_peft_finetuning) + self.handle, enable_peft_finetuning + ) def start_server(self, model): return ffc().flexflow_request_manager_start_background_server( @@ -1742,20 +1748,206 @@ def __init__(self, text: str = None, tokens: list = None): class LoraLinearConfig(object): - __slots__ = ["handle", "_handle"] - def __init__( self, - cache_folder, - peft_model_id, + cache_folder: str, + peft_model_id: str, + trainable: bool = False, + init_lora_weights: bool = False, + rank: int = 8, + lora_alpha: float = 8.0, + lora_dropout: float = 0.0, + target_modules: List[str] = [], + optimizer_type: OptimizerType = OptimizerType.OPTIMIZER_TYPE_NONE, + optimizer_kwargs: dict = {}, ): - c_cache_folder = get_c_name(cache_folder) - peft_model_id = get_c_name(peft_model_id) + self.ff_initialized = False + self._cache_folder = cache_folder + self._peft_model_id = peft_model_id + self._trainable = trainable + self._init_lora_weights = init_lora_weights + self._rank = rank + self._lora_alpha = lora_alpha + self._lora_dropout = lora_dropout + self._target_modules = target_modules + self.optimizer_type = optimizer_type + self.optimizer_kwargs = optimizer_kwargs + + if trainable: + if ( + optimizer_type != OptimizerType.OPTIMIZER_TYPE_SGD + and optimizer_type != OptimizerType.OPTIMIZER_TYPE_ADAM + ): + raise ValueError( + "Please specify optimizer to be used to train LoRA module. Supported optimizers: SGD and Adam" + ) + if init_lora_weights and len(target_modules) == 0: + raise ValueError( + "Please specify target modules to be used to train LoRA module" + ) + else: + if init_lora_weights: + raise ValueError( + "LORA weights initialization from scratch not supported in inference model" + ) + + if rank < 1 or lora_alpha <= 0 or lora_dropout > 1 or lora_dropout < 0.0: + raise ValueError( + "Rank must be >= 1, lora_alpha must be > 0, lora_dropout in interval: [0.0, 1.0]" + ) + + def ff_compile(self): + c_cache_folder = get_c_name(os.path.expanduser(self.cache_folder)) + peft_model_id = get_c_name(self.peft_model_id) + c_target_modules = [ + get_c_name(target_module) for target_module in self.target_modules + ] + c_optimizer_type = enum_to_int(OptimizerType, self.optimizer_type) + # SGD optional optimizer args + sgd_learning_rate = self.optimizer_kwargs.get("learning_rate", 0.001) + sgd_momentum = self.optimizer_kwargs.get("momentum", 0.0) + sgd_nesterov = self.optimizer_kwargs.get("nesterov", False) + sgd_weight_decay = self.optimizer_kwargs.get("weight_decay", 0.0) + # Adam optional optimizer args + adam_alpha = self.optimizer_kwargs.get("alpha", 0.001) + adam_beta1 = self.optimizer_kwargs.get("beta1", 0.9) + adam_beta2 = self.optimizer_kwargs.get("beta2", 0.999) + adam_weight_decay = self.optimizer_kwargs.get("weight_decay", 0.0) + adam_epsilon = self.optimizer_kwargs.get("epsilon", 1e-8) self.handle = ffc().flexflow_lora_linear_config_create( c_cache_folder, peft_model_id, + self.trainable, + self.init_lora_weights, + self.rank, + self.lora_alpha, + self.lora_dropout, + len(self.target_modules), + c_target_modules, + c_optimizer_type, + sgd_learning_rate, + sgd_momentum, + sgd_nesterov, + sgd_weight_decay, + adam_alpha, + adam_beta1, + adam_beta2, + adam_weight_decay, + adam_epsilon, ) self._handle = ffi.gc(self.handle, ffc().flexflow_lora_linear_config_destroy) + self.ff_initialized = True + + @property + def cache_folder(self): + if self.ff_initialized: + c_cache_folder = ffc().flexflow_lora_linear_config_get_cache_folder( + self.handle + ) + return ffi.string(c_cache_folder).decode("utf-8") + else: + 
return self._cache_folder + + @property + def peft_model_id(self): + if self.ff_initialized: + c_peft_model_id = ffc().flexflow_lora_linear_config_get_peft_model_id( + self.handle + ) + return ffi.string(c_peft_model_id).decode("utf-8") + else: + return self._peft_model_id + + @property + def rank(self): + if self.ff_initialized: + return ffc().flexflow_lora_linear_config_get_rank(self.handle) + else: + return self._rank + + @property + def lora_alpha(self): + if self.ff_initialized: + return ffc().flexflow_lora_linear_config_get_lora_alpha(self.handle) + else: + return self._lora_alpha + + @property + def lora_dropout(self): + if self.ff_initialized: + return ffc().flexflow_lora_linear_config_get_lora_dropout(self.handle) + else: + return self._lora_dropout + + @property + def trainable(self): + if self.ff_initialized: + return ffc().flexflow_lora_linear_config_get_trainable(self.handle) + else: + return self._trainable + + @property + def init_lora_weights(self): + if self.ff_initialized: + return ffc().flexflow_lora_linear_config_get_init_lora_weights(self.handle) + else: + return self._init_lora_weights + + @property + def target_modules(self): + if self.ff_initialized: + num_target_modules = ffi.new("int *") + c_target_modules = ffc().flexflow_lora_linear_config_get_target_modules( + self.handle, num_target_modules + ) + target_modules = [] + for i in range(num_target_modules[0]): + target_modules.append(ffi.string(c_target_modules[i]).decode("utf-8")) + return target_modules + else: + return self._target_modules + + @cache_folder.setter + def cache_folder(self, value: str): + self._cache_folder = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_cache_folder(self.handle, value) + + @peft_model_id.setter + def peft_model_id(self, value: str): + self._peft_model_id = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_peft_model_id(self.handle, value) + + @rank.setter + def rank(self, value: int): + self._rank = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_rank(self.handle, value) + + @lora_alpha.setter + def lora_alpha(self, value: float): + self._lora_alpha = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_lora_alpha(self.handle, value) + + @lora_dropout.setter + def lora_dropout(self, value: float): + self._lora_dropout = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_lora_dropout(self.handle, value) + + @trainable.setter + def trainable(self, value: bool): + self._trainable = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_trainable(self.handle, value) + + @init_lora_weights.setter + def init_lora_weights(self, value: bool): + self._init_lora_weights = value + if self.ff_initialized: + ffc().flexflow_lora_linear_config_set_init_lora_weights(self.handle, value) # ----------------------------------------------------------------------- @@ -1781,6 +1973,7 @@ def no_id_handle(): PEFTModelID.__no_id_h = ffc().flexflow_peft_model_id_no_id() return PEFTModelID.__no_id_h + # ----------------------------------------------------------------------- # Request # ----------------------------------------------------------------------- @@ -2561,9 +2754,10 @@ def residual_layer_norm( c_name, ) self.add_layer(OpType.RESIDUAL_LAYERNORM, name) - return Tensor( - handles_array[0], owner_op_type=OpType.RESIDUAL_LAYERNORM - ), Tensor(handles_array[1], owner_op_type=OpType.RESIDUAL_LAYERNORM) + return ( + Tensor(handles_array[0], 
owner_op_type=OpType.RESIDUAL_LAYERNORM), + Tensor(handles_array[1], owner_op_type=OpType.RESIDUAL_LAYERNORM), + ) def add_bias_residual_layer_norm( self, @@ -2614,9 +2808,10 @@ def add_bias_residual_layer_norm( c_name, ) self.add_layer(OpType.ADD_BIAS_RESIDUAL_LAYERNORM, name) - return Tensor( - handles_array[0], owner_op_type=OpType.ADD_BIAS_RESIDUAL_LAYERNORM - ), Tensor(handles_array[1], owner_op_type=OpType.ADD_BIAS_RESIDUAL_LAYERNORM) + return ( + Tensor(handles_array[0], owner_op_type=OpType.ADD_BIAS_RESIDUAL_LAYERNORM), + Tensor(handles_array[1], owner_op_type=OpType.ADD_BIAS_RESIDUAL_LAYERNORM), + ) def sigmoid_silu_multi(self, input1, input2, name=None): c_name = get_c_name(name) @@ -3928,8 +4123,9 @@ def residual_rms_norm( c_name, ) self.add_layer(OpType.RESIDUAL_RMS_NORM, name) - return Tensor(handles_array[0], owner_op_type=OpType.RESIDUAL_RMS_NORM), Tensor( - handles_array[1], owner_op_type=OpType.RESIDUAL_RMS_NORM + return ( + Tensor(handles_array[0], owner_op_type=OpType.RESIDUAL_RMS_NORM), + Tensor(handles_array[1], owner_op_type=OpType.RESIDUAL_RMS_NORM), ) def arg_top_k(self, input, k, sorted, speculative_decoding, name=None): @@ -4026,9 +4222,7 @@ def argmax(self, input, beam_search, name=None): return Tensor(handle, owner_op_type=OpType.ARGMAX) def add_lora_layer(self, peft_config): - handle = ffc().flexflow_model_add_lora_layer(self.handle, peft_config.handle) - return handle - # self.add_layer(OpType.LORA, name) + return ffc().flexflow_model_add_lora_layer(self.handle, peft_config.handle) def reset_metrics(self): """Reset performance metrics. @@ -4459,9 +4653,13 @@ def generate(self, requests_list: List[Request]): request.max_sequence_length for request in requests_list ] peft_model_ids = [ - (request.peft_model_id - if request.peft_model_id is not None else PEFTModelID.no_id_handle()) - for request in requests_list] + ( + request.peft_model_id + if request.peft_model_id is not None + else PEFTModelID.no_id_handle() + ) + for request in requests_list + ] dataset_filepaths = [ get_c_name(request.dataset_filepath) for request in requests_list ] diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 8ee20f9547..96f0258572 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -62,7 +62,7 @@ def __init__( # self.llama_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath - self.maxint = 2**31 - 1 + self.maxint = 2 ** 31 - 1 max_verify_tokens_per_batch = ( max_tokens_per_batch + self.llama_config.max_spec_tree_token_num ) @@ -258,8 +258,4 @@ def convert_hf_model(model, dst_folder): os.makedirs(dst_folder, exist_ok=True) for name, params in model.named_parameters(): name = FlexFlowLLAMA.convert_hf_weight_name(name) - if "lm_head" in name: - print("Encountered lm_head, shape", params.detach().cpu().numpy().shape) - if "embed_tokens" in name: - print("Encountered embed_tokens, shape", params.detach().cpu().numpy().shape) params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}") diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index 30144e1174..319505794a 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -28,16 +28,14 @@ ) from flexflow.core import * from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer -from peft import PeftModel, PeftConfig +from peft import PeftModel, PeftConfig, LoraConfig from huggingface_hub 
import HfApi import torch, shutil, hashlib, json, gc from typing import Union, List class _SupportedModels: - def __init__( - self, - ): + def __init__(self,): self.supported_models = { "LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig), "LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig), @@ -109,37 +107,60 @@ def __del__(self): if type(self) == LLM and self.rm is not None: self.rm.stop_server() - def add_peft(self, peft_model_id: str): - """Add a previously created PEFT adapter to the LLM. The PEFT model should already exist locally or be available on HuggingFace""" - peft_config = PeftConfig.from_pretrained(peft_model_id) - peft_type = peft_config.peft_type - if peft_type != "LORA": - raise RuntimeError(f"PEFT type {peft_type} not yet supported in FlexFlow") + def add_peft(self, lora_config: LoraLinearConfig): + """Add a PEFT adapter to the LLM""" + if lora_config is None: + raise ValueError("lora_config cannot be None") + if len(lora_config.peft_model_id or "") == 0: + raise ValueError("PEFT model id cannot be empty") + # Inference (trainable=False): LoRA model should already exist in huggingface. Any changes of parameters from original model are ignored + # Training (trainable=True): Either an existing model (init_lora_weights=False) or a new one (init_lora_weights=True) + + if lora_config.trainable == False or not lora_config.init_lora_weights: + peft_config = PeftConfig.from_pretrained(lora_config.peft_model_id) + else: + peft_config = LoraConfig( + peft_type="LORA", + base_model_name_or_path=self.model_name, + r=lora_config.rank, + target_modules=lora_config.target_modules, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + ) + if peft_config.peft_type != "LORA": + raise RuntimeError( + f"PEFT type {peft_config.peft_type} not yet supported in FlexFlow" + ) if "base_model_name_or_path" not in peft_config.to_dict(): raise ValueError( - f"PEFT model {peft_model_id} does not have an associated base model" + f"PEFT model {lora_config.peft_model_id} does not have an associated base model" ) if peft_config.base_model_name_or_path != self.model_name: raise RuntimeError( f"Attempting to add PEFT with base model name {peft_config.base_model_name_or_path} to LLM {self.model_name}" ) - peft_dict = { + + self.pefts[lora_config] = { "peft_config": peft_config, - "peft_type": peft_type, + "peft_type": peft_config.peft_type, } - self.pefts[peft_model_id] = peft_dict - def get_ff_peft_id(self, peft_model_id: str) -> PEFTModelID: - if peft_model_id not in self.pefts: + def get_ff_peft_id(self, lora_config: LoraLinearConfig) -> PEFTModelID: + if lora_config is None: + raise ValueError("lora_config cannot be None") + if len(lora_config.peft_model_id or "") == 0: + raise ValueError("PEFT model id cannot be empty") + if lora_config not in self.pefts: raise ValueError( - f"PEFT {peft_model_id} not registered with LLM {self.model_name}" + f"PEFT {lora_config} not registered with LLM {self.model_name}" ) - peft_dict = self.pefts[peft_model_id] - if "ff_peft_model_id" not in peft_dict: + if "ff_peft_model_id" not in self.pefts[lora_config]: raise RuntimeError( - f"Attempting to run PEFT {peft_model_id} before compiling LLM {self.model_name}" + f"Attempting to run PEFT {lora_config} before compiling LLM {self.model_name}" ) - return peft_dict["ff_peft_model_id"] + + return self.pefts[lora_config]["ff_peft_model_id"] def download_hf_config(self): """Save the HuggingFace model configs to a json file. 
Useful mainly to run the C++ inference code.""" @@ -153,12 +174,11 @@ def download_hf_config(self): self.hf_config.to_json_file(config_path) # Save PEFT configs if the LLM has any registered PEFTs - for peft_model_id, peft_dict in self.pefts.items(): + for ff_peft_config, peft_dict in self.pefts.items(): peft_config = peft_dict["peft_config"] + peft_model_id = ff_peft_config.peft_model_id peft_config_dir = os.path.join( - os.path.expanduser(self.cache_path), - "configs", - peft_model_id.lower(), + os.path.expanduser(self.cache_path), "configs", peft_model_id.lower() ) os.makedirs(peft_config_dir, exist_ok=True) peft_config_path = os.path.join(peft_config_dir, "config.json") @@ -264,14 +284,15 @@ def convert_peft_model(hf_peft_model, peft_type, weights_path): params.detach().cpu().numpy().tofile(f"{weights_path}/{name}") def download_peft_weights(): - for peft_model_id, peft_dict in self.pefts.items(): + for ff_peft_config, peft_dict in self.pefts.items(): peft_config = peft_dict["peft_config"] peft_type = peft_dict["peft_type"] + peft_model_id = ff_peft_config.peft_model_id weights_path = get_weights_path(peft_model_id) refresh_cache_if_needed(peft_model_id) - ff_revision, ff_revision_file, latest_revision = ( - self.__get_revision_hashes(peft_model_id, weights_path) + ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes( + peft_model_id, weights_path ) if ff_revision != latest_revision: @@ -306,9 +327,7 @@ def download_hf_tokenizer_if_needed(self): # Use local cache, or download new version self.tokenizer_path = os.path.join( - os.path.expanduser(self.cache_path), - "tokenizers", - self.model_name.lower(), + os.path.expanduser(self.cache_path), "tokenizers", self.model_name.lower() ) if self.refresh_cache: print( @@ -423,11 +442,8 @@ def compile( ) # Add PEFT layer if registered - for peft_model_id, peft_dict in self.pefts.items(): - # ff_peft_config = peft_dict["ff_peft_config"] - ff_peft_config = LoraLinearConfig( - os.path.expanduser(self.cache_path), peft_model_id - ) + for ff_peft_config, peft_dict in self.pefts.items(): + ff_peft_config.ff_compile() ff_peft_model_id = self.model.ffmodel.add_lora_layer(ff_peft_config) peft_dict["ff_peft_model_id"] = ff_peft_model_id @@ -439,8 +455,7 @@ def compile( self.rm.set_max_spec_tree_token_num( model_configs.max_spec_tree_token_num - if "max_spec_tree_token_num" - in model_configs.__dict__ + if "max_spec_tree_token_num" in model_configs.__dict__ else 20 ) @@ -550,13 +565,7 @@ def __init__( :param output_file: Path to the output file. 
If left blank, the output will not be written to file, defaults to "" :type output_file: str, optional """ - super().__init__( - model_name, - data_type, - cache_path, - refresh_cache, - output_file, - ) + super().__init__(model_name, data_type, cache_path, refresh_cache, output_file) def compile( self, diff --git a/python/flexflow/type.py b/python/flexflow/type.py index ac6975b4fd..0f4726837c 100644 --- a/python/flexflow/type.py +++ b/python/flexflow/type.py @@ -46,6 +46,12 @@ class LossType(Enum): LOSS_IDENTITY = 54 +class OptimizerType(Enum): + OPTIMIZER_TYPE_NONE = 60 + OPTIMIZER_TYPE_SGD = 61 + OPTIMIZER_TYPE_ADAM = 62 + + class CompMode(Enum): TRAINING = 70 INFERENCE = 71 @@ -152,10 +158,12 @@ class OpType(Enum): RESIDUAL_RMS_NORM = 2305 RESIDUAL_LAYERNORM = 2306 + class RequestType(Enum): REQ_INFERENCE = 4001 REQ_FINETUNING = 4002 + def enum_to_int(enum, enum_item): for item in enum: if enum_item == item: diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 76ca5053d6..3ba6398db1 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -67,6 +67,11 @@ class FFCObjectWrapper { FF_NEW_OPAQUE_WRAPPER(flexflow_request_manager_t, RequestManager *); FF_NEW_OPAQUE_WRAPPER(flexflow_file_data_loader_t, FileDataLoader *); FF_NEW_OPAQUE_WRAPPER(flexflow_generation_result_t, GenerationResult *); + // FF_NEW_OPAQUE_WRAPPER(flexflow_lora_optimizer_config_t, LoraOptimizerConfig + // *); FF_NEW_OPAQUE_WRAPPER(flexflow_lora_sgd_optimizer_config_t, + // LoraSGDOptimizerConfig *); + // FF_NEW_OPAQUE_WRAPPER(flexflow_lora_adam_optimizer_config_t, + // LoraAdamOptimizerConfig *); FF_NEW_OPAQUE_WRAPPER(flexflow_lora_linear_config_t, LoraLinearConfig *); FF_NEW_OPAQUE_WRAPPER(flexflow_peft_model_id_t, PEFTModelID *); }; @@ -2804,20 +2809,102 @@ void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_, handle->load_weights(model); } +// // ----------------------------------------------------------------------- +// // LoraSGDOptimizerConfig +// // ----------------------------------------------------------------------- + +// flexflow_lora_sgd_optimizer_config_t +// flexflow_lora_sgd_optimizer_config_create( +// double lr, double momentum, bool nesterov, bool weight_decay) { +// LoraSGDOptimizerConfig *handle = +// new LoraSGDOptimizerConfig(lr, momentum, nesterov, weight_decay); +// DEBUG_PRINT("[LoraSGDOptimizerConfig] new %p", handle); +// return FFCObjectWrapper::wrap(handle); +// } + +// void flexflow_lora_sgd_optimizer_config_destroy( +// flexflow_lora_sgd_optimizer_config_t handle_) { +// LoraSGDOptimizerConfig *handle = FFCObjectWrapper::unwrap(handle_); +// DEBUG_PRINT("[LoraSGDOptimizerConfig] delete %p", handle); +// delete handle; +// } + +// // ----------------------------------------------------------------------- +// // LoraAdamOptimizerConfig +// // ----------------------------------------------------------------------- + +// flexflow_lora_adam_optimizer_config_t +// flexflow_lora_adam_optimizer_config_create(double alpha, +// double beta1, +// double beta2, +// double weight_decay, +// double epsilon) { +// LoraAdamOptimizerConfig *handle = +// new LoraAdamOptimizerConfig(alpha, beta1, beta2, weight_decay, +// epsilon); +// DEBUG_PRINT("[LoraAdamOptimizerConfig] new %p", handle); +// return FFCObjectWrapper::wrap(handle); +// } + +// void flexflow_lora_adam_optimizer_config_destroy( +// flexflow_lora_adam_optimizer_config_t handle_) { +// LoraAdamOptimizerConfig *handle = FFCObjectWrapper::unwrap(handle_); +// DEBUG_PRINT("[LoraAdamOptimizerConfig] delete 
%p", handle); +// delete handle; +// } + // ----------------------------------------------------------------------- // LoraLinearConfig // ----------------------------------------------------------------------- flexflow_lora_linear_config_t flexflow_lora_linear_config_create(char const *cache_folder_, - char const *peft_model_id_) { + char const *peft_model_id_, + bool trainable, + bool init_lora_weights, + int rank, + float lora_alpha, + float lora_dropout, + int num_target_modules, + char const **target_modules_, + enum OptimizerType optimizer_type, + float sgd_learning_rate, + float sgd_momentum, + bool sgd_nesterov, + float sgd_weight_decay, + float adam_alpha, + float adam_beta1, + float adam_beta2, + float adam_weight_decay, + float adam_epsilon) { assert(cache_folder_ != nullptr && "Cannot convert nullptr char * to std::string"); assert(peft_model_id_ != nullptr && "Cannot convert nullptr char * to std::string"); std::string const cache_folder(cache_folder_); std::string const peft_model_id(peft_model_id_); - LoraLinearConfig *handle = new LoraLinearConfig(cache_folder, peft_model_id); + LoraOptimizerConfig *optim_config = nullptr; + if (optimizer_type == OptimizerType::OPTIMIZER_TYPE_SGD) { + optim_config = new LoraSGDOptimizerConfig( + sgd_learning_rate, sgd_momentum, sgd_nesterov, sgd_weight_decay); + } else if (optimizer_type == OptimizerType::OPTIMIZER_TYPE_ADAM) { + optim_config = new LoraAdamOptimizerConfig( + adam_alpha, adam_beta1, adam_beta2, adam_weight_decay, adam_epsilon); + } + std::vector<std::string> target_modules; + for (int i = 0; i < num_target_modules; i++) { + std::string const target_module(target_modules_[i]); + target_modules.push_back(target_module); + } + LoraLinearConfig *handle = new LoraLinearConfig(cache_folder, + peft_model_id, + trainable, + optim_config, + init_lora_weights, + rank, + lora_alpha, + lora_dropout, + target_modules); DEBUG_PRINT("[LoraLinearConfig] new %p", handle); return FFCObjectWrapper::wrap(handle); } @@ -2829,6 +2916,84 @@ void flexflow_lora_linear_config_destroy( delete peft_config; } +char const *flexflow_lora_linear_config_get_cache_folder( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->cache_folder.c_str(); +} + +char const *flexflow_lora_linear_config_get_peft_model_id( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->peft_model_id.c_str(); +} + +int flexflow_lora_linear_config_get_rank( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->rank; +} + +float flexflow_lora_linear_config_get_lora_alpha( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->lora_alpha; +} + +float flexflow_lora_linear_config_get_lora_dropout( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->lora_dropout; +} + +bool flexflow_lora_linear_config_get_trainable( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->trainable; +} + +bool flexflow_lora_linear_config_get_init_lora_weights( + flexflow_lora_linear_config_t handle_) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->init_lora_weights; +} + +char const **flexflow_lora_linear_config_get_target_modules( + 
flexflow_lora_linear_config_t handle_, int *num_target_modules) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + *num_target_modules = handle->target_modules.size(); + static std::vector<char const *> target_modules_; + target_modules_.clear(); + for (auto const &target_module : handle->target_modules) { + target_modules_.push_back(target_module.c_str()); + } + return target_modules_.data(); +} + +void flexflow_lora_linear_config_set_lora_alpha( + flexflow_lora_linear_config_t handle_, float value) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + handle->lora_alpha = value; +} + +void flexflow_lora_linear_config_set_lora_dropout( + flexflow_lora_linear_config_t handle_, float value) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + handle->lora_dropout = value; +} + +void flexflow_lora_linear_config_set_trainable( + flexflow_lora_linear_config_t handle_, bool value) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + handle->trainable = value; +} + +void flexflow_lora_linear_config_set_init_lora_weights( + flexflow_lora_linear_config_t handle_, bool value) { + LoraLinearConfig *handle = FFCObjectWrapper::unwrap(handle_); + handle->init_lora_weights = value; +} + // ----------------------------------------------------------------------- // PEFTModelID // ----------------------------------------------------------------------- diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu index 7a7712ba43..6c8a74b71b 100644 --- a/src/ops/kernels/lora_linear_kernels.cu +++ b/src/ops/kernels/lora_linear_kernels.cu @@ -246,7 +246,7 @@ void inference_kernel(LoraLinearMeta *m, // [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens] // Note that we use alpha in both places since we do // an in-place update for LoraLinear - double lora_alpha = + float lora_alpha = m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha; DT scaling_constant = (DT)(lora_alpha / rank); checkCUDA(cublasGemmEx(m->handle.blas, @@ -341,7 +341,7 @@ void peft_bwd_kernel(LoraLinearMeta *m, LoraLinearWeight weight = m->model_state[bc->requestsInfo[i].peft_model_id].weights; int rank = weight.rank; - double lora_alpha = + float lora_alpha = m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha; DT scaling_constant = (DT)(lora_alpha / rank); // Compute LORA_B weight's gradient diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 702ad59aea..b489617ee5 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -56,7 +56,8 @@ PEFTModelID *FFModel::add_lora_layer(LoraLinearConfig const peft_config) { "Cannot add a LoRA layer if PEFT mode is not enabled"); if (peft_config.target_modules.size() == 0) { printf("PEFT config does not contain any target module\n"); - return nullptr; + std::cout << peft_config << std::endl; + assert(false); } PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++); peft_configs[*peft_model_id] = peft_config; @@ -699,7 +700,6 @@ void LoraLinear::inference_task(Task const *task, } // input activation (intermediate) filename = dst_filepath.string() + ".low_rank_activation"; - assert(num_tokens == 128); if (output.data_type == DT_FLOAT) { save_tensor( (float *)m->low_rank_activation, rank * num_tokens, filename.c_str()); @@ -821,6 +821,7 @@ void LoraLinear::peft_bwd_task(Task const *task, // weights, weights gradients fs::path dst_filepath_weights = get_dst_folder("weights", m->bwd_step, shard_id) / layername; + assert(m->model_state.size() 
>= 1 && "Model state empty!"); for (auto it = m->model_state.begin(); it != m->model_state.end(); ++it) { PEFTModelID peft_model_id = it->first; LoraLinearWeight weight = m->model_state[peft_model_id].weights; @@ -937,6 +938,35 @@ bool operator==(LoraLinearParams const &lhs, LoraLinearParams const &rhs) { return false; } +fs::path create_unique_temp_directory() { + std::srand(static_cast<unsigned int>(std::time(nullptr))); + + fs::path temp_dir = fs::temp_directory_path(); + fs::path unique_path; + + do { + std::string unique_name = "flexflow_tmp_" + std::to_string(std::rand()); + unique_path = temp_dir / unique_name; + } while (fs::exists(unique_path)); + + fs::create_directory(unique_path); + return unique_path; +} + +void serialize_string(Legion::Serializer &sez, + std::string string_to_serialize) { + sez.serialize(string_to_serialize.length()); + sez.serialize(string_to_serialize.c_str(), string_to_serialize.length()); +} + +std::string deserialize_string(Legion::Deserializer &dez) { + size_t string_size; + char buffer[4096] = {0}; + dez.deserialize(string_size); + dez.deserialize(buffer, string_size); + return std::string(buffer); +} + void LoraLinear::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); @@ -946,40 +976,38 @@ void LoraLinear::serialize(Legion::Serializer &sez) const { for (auto const &kv : this->peft_configs) { // Serialize PEFTModelID sez.serialize(kv.first.id); - // Serialize LoraConfig's cache folder - sez.serialize(kv.second.cache_folder.length()); - sez.serialize(kv.second.cache_folder.c_str(), - kv.second.cache_folder.length()); - // Serialize LoraConfig's peft model id - sez.serialize(kv.second.peft_model_id.length()); - sez.serialize(kv.second.peft_model_id.c_str(), - kv.second.peft_model_id.length()); - // serialize whether we should expect an optimizer or not - sez.serialize(kv.second.trainable); + + // Serialize LoraLinearConfig and OptimizerConfig to tmp folder + // 1. Create tmp dir and serialize it + fs::path unique_temp_dir = create_unique_temp_directory(); + serialize_string(sez, unique_temp_dir.string()); + // 2. Dump LoraLinearConfig to json file in tmp dir + std::string lora_config_filename = std::string("lora_linear_config_") + + std::to_string(kv.first.id) + + std::string(".json"); + fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename; + serialize_to_json_file(kv.second, lora_config_json_filepath); + // 3. 
Dump optimizer to json file in tmp dir, and serialize optimizer type + std::string optimizer_filename = std::string("optimizer_config_") + + std::to_string(kv.first.id) + + std::string(".json"); + fs::path optim_config_filepath = unique_temp_dir / optimizer_filename; assert((kv.second.trainable) == (kv.second.optimizer_config != nullptr)); if (kv.second.trainable) { - // Serialize LoraConfig's optimizer config if (typeid(*kv.second.optimizer_config) == typeid(LoraSGDOptimizerConfig)) { sez.serialize(OPTIMIZER_TYPE_SGD); LoraSGDOptimizerConfig const *sgd_config = static_cast<LoraSGDOptimizerConfig const *>( kv.second.optimizer_config); - sez.serialize(sgd_config->lr); - sez.serialize(sgd_config->momentum); - sez.serialize(sgd_config->nesterov); - sez.serialize(sgd_config->weight_decay); + serialize_to_json_file(*sgd_config, optim_config_filepath); } else if (typeid(*kv.second.optimizer_config) == typeid(LoraAdamOptimizerConfig)) { sez.serialize(OPTIMIZER_TYPE_ADAM); LoraAdamOptimizerConfig const *adam_config = static_cast<LoraAdamOptimizerConfig const *>( kv.second.optimizer_config); - sez.serialize(adam_config->alpha); - sez.serialize(adam_config->beta1); - sez.serialize(adam_config->beta2); - sez.serialize(adam_config->weight_decay); - sez.serialize(adam_config->epsilon); + serialize_to_json_file(*adam_config, optim_config_filepath); } else { assert(false && "Optimizer type not yet supported"); } @@ -1014,54 +1042,48 @@ Node LoraLinear::deserialize(FFModel &ff, size_t pid; dez.deserialize(pid); PEFTModelID peft_model_id(pid); - - // Deserialize LoraConfig's cache folder - size_t string_size; - char buffer[4096] = {0}; - dez.deserialize(string_size); - dez.deserialize(buffer, string_size); - std::string cache_folder = std::string(buffer); - - // Deserialize LoraConfig's peft model id - string_size = 0; - memset(buffer, 0, 4096); - dez.deserialize(string_size); - dez.deserialize(buffer, string_size); - std::string peft_model_name = std::string(buffer); - - bool trainable; - LoraOptimizerConfig *optimizer_config_ = nullptr; - OptimizerType type_; - dez.deserialize(trainable); - if (trainable) { + // Deserialize tmp folder containing LoraLinearConfig and optimizer config + fs::path unique_temp_dir = fs::path(deserialize_string(dez)); + // 1. Deserialize LoraLinearConfig + std::string lora_config_filename = std::string("lora_linear_config_") + + std::to_string(pid) + + std::string(".json"); + fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename; + std::unique_ptr<LoraLinearConfig> lora_linear_config = + deserialize_from_json_file<LoraLinearConfig>(lora_config_json_filepath); + // 2. 
Deserialize optimizer if needed + if (lora_linear_config->trainable) { + std::string optimizer_filename = std::string("optimizer_config_") + + std::to_string(pid) + + std::string(".json"); + fs::path optim_config_filepath = unique_temp_dir / optimizer_filename; + OptimizerType type_; dez.deserialize(type_); if (type_ == OPTIMIZER_TYPE_SGD) { - double lr, momentum, weight_decay; - bool nesterov; - dez.deserialize(lr); - dez.deserialize(momentum); - dez.deserialize(nesterov); - dez.deserialize(weight_decay); - optimizer_config_ = - new LoraSGDOptimizerConfig(lr, momentum, nesterov, weight_decay); + std::unique_ptr<LoraSGDOptimizerConfig> sgd_optimizer_config = + deserialize_from_json_file<LoraSGDOptimizerConfig>( + optim_config_filepath); + lora_linear_config->optimizer_config = + dynamic_cast<LoraOptimizerConfig *>(sgd_optimizer_config.release()); } else if (type_ == OPTIMIZER_TYPE_ADAM) { - double alpha, beta1, beta2, weight_decay, epsilon; - dez.deserialize(alpha); - dez.deserialize(beta1); - dez.deserialize(beta2); - dez.deserialize(weight_decay); - dez.deserialize(epsilon); - optimizer_config_ = new LoraAdamOptimizerConfig( - alpha, beta1, beta2, weight_decay, epsilon); + std::unique_ptr<LoraAdamOptimizerConfig> adam_optimizer_config = + deserialize_from_json_file<LoraAdamOptimizerConfig>( + optim_config_filepath); + lora_linear_config->optimizer_config = + dynamic_cast<LoraOptimizerConfig *>( + adam_optimizer_config.release()); } else { printf("Optimizer type: %d\n", type_); assert(false && "Optimizer type not yet supported"); } } - LoraLinearConfig lora_linear_config( - cache_folder, peft_model_name, trainable, optimizer_config_); + try { + fs::remove_all(unique_temp_dir); + } catch (fs::filesystem_error const &e) { + std::cerr << "Error removing tmp directory: " << e.what() << std::endl; + } params.peft_configs.emplace( - std::make_pair(peft_model_id, lora_linear_config)); + std::make_pair(peft_model_id, *lora_linear_config)); } dez.deserialize(name_len); dez.deserialize(name, name_len); @@ -1115,7 +1137,7 @@ size_t hash<FlexFlow::LoraLinearParams>::operator()( hash_combine(key, kv.second.lora_alpha); hash_combine(key, kv.second.lora_dropout); hash_combine(key, kv.second.target_modules); - hash_combine(key, kv.second.load_weights_from_file); + hash_combine(key, kv.second.init_lora_weights); } return key; } diff --git a/src/ops/lora_linear_params.cc b/src/ops/lora_linear_params.cc index a487d12c80..6400e6a500 100644 --- a/src/ops/lora_linear_params.cc +++ b/src/ops/lora_linear_params.cc @@ -50,62 +50,117 @@ std::ostream &operator<<(std::ostream &os, LoraAdamOptimizerConfig const &llc) { return os; } +// Serialization helpers +template <typename T> +void serialize_to_json_file(T const &obj, fs::path const &filepath) { + json j = obj; + std::ofstream file(filepath); + file << j.dump(4); +} + +template <typename T> +std::unique_ptr<T> deserialize_from_json_file(fs::path const &filepath) { + std::ifstream file(filepath); + json j; + file >> j; + return std::make_unique<T>(j.get<T>()); +} + +template void + serialize_to_json_file<LoraLinearConfig>(LoraLinearConfig const &obj, + fs::path const &filepath); +template void serialize_to_json_file<LoraSGDOptimizerConfig>( + LoraSGDOptimizerConfig const &obj, fs::path const &filepath); +template void serialize_to_json_file<LoraAdamOptimizerConfig>( + LoraAdamOptimizerConfig const &obj, fs::path const &filepath); +template std::unique_ptr<LoraLinearConfig> + deserialize_from_json_file<LoraLinearConfig>(fs::path const &filepath); +template 
std::unique_ptr<LoraSGDOptimizerConfig> + deserialize_from_json_file<LoraSGDOptimizerConfig>( + fs::path const &filepath); +template std::unique_ptr<LoraAdamOptimizerConfig> + deserialize_from_json_file<LoraAdamOptimizerConfig>( + fs::path const &filepath); + // ------------------ LoRA configs ------------------- // --------------------------------------------------- -const LoraLinearConfig LoraLinearConfig::EmptyConfig = LoraLinearConfig(); - -LoraLinearConfig::LoraLinearConfig() - : rank(0), trainable(false), optimizer_config(nullptr), cache_folder(""), - peft_model_id(""), lora_alpha(0), lora_dropout(0.0f), - load_weights_from_file(false) {} - -LoraLinearConfig::LoraLinearConfig(int _rank, - bool _trainable, - LoraOptimizerConfig *_optimizer_config) - : rank(_rank), trainable(_trainable), optimizer_config(_optimizer_config), - cache_folder(""), peft_model_id(""), lora_alpha(0), lora_dropout(0.0f), - load_weights_from_file(false) {} - -LoraLinearConfig::LoraLinearConfig(std::string const &cache_folder_, - std::string const &peft_model_id_, - bool trainable_, - LoraOptimizerConfig *optimizer_config_) - : cache_folder(cache_folder_), peft_model_id(peft_model_id_), +const LoraLinearConfig LoraLinearConfig::EmptyConfig = LoraLinearConfig("", ""); + +LoraLinearConfig::LoraLinearConfig( + std::string const &cache_folder_, + std::string const &peft_model_id_, + bool trainable_, + LoraOptimizerConfig *optimizer_config_, + bool init_lora_weights_, + int rank_, + float lora_alpha_, + float lora_dropout_, + std::vector<std::string> const &target_modules_) + : cache_folder(cache_folder_), peft_model_id(peft_model_id_), rank(rank_), + lora_alpha(lora_alpha_), lora_dropout(lora_dropout_), trainable(trainable_), optimizer_config(optimizer_config_), - load_weights_from_file(true) { - std::string peft_inference_config_file_path = - join_path({cache_folder, "configs", peft_model_id, "config.json"}); - std::ifstream config_file(peft_inference_config_file_path); - if (config_file.is_open()) { - try { - json model_config; - config_file >> model_config; - rank = model_config["r"]; - lora_alpha = model_config["lora_alpha"]; - lora_dropout = model_config["lora_dropout"]; - for (auto &s : model_config["target_modules"]) { - target_modules.push_back(s); + init_lora_weights(init_lora_weights_), target_modules(target_modules_) { + + if (peft_model_id.empty()) { + return; + } + assert(!cache_folder.empty() && + "cache_folder must be provided when using PEFT"); + if (trainable) { + assert(optimizer_config != nullptr && + "optimizer_config must be provided when using PEFT"); + } else { + assert(init_lora_weights == false && + "init_lora_weights must be false when LORA not trainable"); + assert(optimizer_config == nullptr && + "optimizer_config must be nullptr when not trainable"); + } + // if we are not initializing LORA from scratch, load the configs from + // existing repository + if (!init_lora_weights) { + std::string peft_inference_config_file_path = + join_path({cache_folder, "configs", peft_model_id, "config.json"}); + std::ifstream config_file(peft_inference_config_file_path); + if (config_file.is_open()) { + try { + json model_config; + config_file >> model_config; + rank = model_config["r"]; + lora_alpha = float(model_config["lora_alpha"]); + lora_dropout = model_config["lora_dropout"]; + for (auto &s : model_config["target_modules"]) { + target_modules.push_back(s); + } + } catch (json::exception const &e) { + std::cerr << "Error parsing PEFT config from JSON file: " << e.what() + << std::endl; + 
assert(false); } - } catch (json::exception const &e) { - std::cerr << "Error parsing PEFT config from JSON file: " << e.what() + } else { + std::cerr << "Error opening JSON file " << peft_inference_config_file_path << std::endl; assert(false); } - } else { - std::cerr << "Error opening JSON file " << peft_inference_config_file_path - << std::endl; - assert(false); } + assert(rank > 0 && "rank must be greater than 0"); + assert(lora_alpha > 0.0f && "lora_alpha must be greater than 0.0"); + assert(lora_dropout >= 0.0f && lora_dropout <= 1.0f && + "lora_dropout must be in [0.0, 1.0]"); + assert(target_modules.size() > 0 && "target_modules must not be left empty"); } +// constructor used to support unordered_map +LoraLinearConfig::LoraLinearConfig() : LoraLinearConfig("", "") {} + bool operator==(LoraLinearConfig const &lhs, LoraLinearConfig const &rhs) { - if (lhs.rank == rhs.rank && lhs.optimizer_config == rhs.optimizer_config && - lhs.cache_folder == rhs.cache_folder && - lhs.peft_model_id == rhs.peft_model_id && + if (lhs.cache_folder == rhs.cache_folder && + lhs.peft_model_id == rhs.peft_model_id && lhs.rank == rhs.rank && lhs.lora_alpha == rhs.lora_alpha && lhs.lora_dropout == rhs.lora_dropout && lhs.target_modules.size() == rhs.target_modules.size() && - lhs.load_weights_from_file == rhs.load_weights_from_file) { + lhs.trainable == rhs.trainable && + lhs.init_lora_weights == rhs.init_lora_weights && + lhs.optimizer_config == rhs.optimizer_config) { for (int i = 0; i < lhs.target_modules.size(); i++) { if (lhs.target_modules[i] != rhs.target_modules[i]) { return false; @@ -118,8 +173,20 @@ bool operator==(LoraLinearConfig const &lhs, LoraLinearConfig const &rhs) { std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc) { os << "LoraLinearConfig: "; - os << "trainable: " << llc.trainable << ", "; + os << "cache_folder: " << llc.cache_folder << ", "; + os << "peft_model_id: " << llc.peft_model_id << ", "; os << "rank: " << llc.rank << ", "; + os << "lora_alpha: " << llc.lora_alpha << ", "; + os << "lora_dropout: " << llc.lora_dropout << ", "; + os << "target_modules: ["; + for (int i = 0; i < llc.target_modules.size(); i++) { + os << llc.target_modules[i]; + if (i < llc.target_modules.size() - 1) { + os << ", "; + } + } + os << "], "; + os << "trainable: " << llc.trainable << ", "; if (llc.optimizer_config != nullptr) { os << "optimizer_config: "; if (typeid(*llc.optimizer_config) == typeid(LoraSGDOptimizerConfig)) { @@ -132,19 +199,7 @@ std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc) { } std::cout << std::endl; } - os << "cache_folder: " << llc.cache_folder << ", "; - os << "peft_model_id: " << llc.peft_model_id << ", "; - os << "lora_alpha: " << llc.lora_alpha << ", "; - os << "lora_dropout: " << llc.lora_dropout << ", "; - os << "target_modules: ["; - for (int i = 0; i < llc.target_modules.size(); i++) { - os << llc.target_modules[i]; - if (i < llc.target_modules.size() - 1) { - os << ", "; - } - } - os << "], "; - os << "load_weights_from_file: " << llc.load_weights_from_file << std::endl; + os << "init_lora_weights: " << llc.init_lora_weights << std::endl; return os; } diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index 027ca7f5c0..b5d60787af 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -143,6 +143,8 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) { os << " PEFT Model ID: " << bc.requestsInfo[i].peft_model_id << std::endl; os << " PEFT bwd: " << 
bc.requestsInfo[i].peft_bwd << std::endl; + os << " GradientsUpdateMode: " + << bc.requestsInfo[i].gradients_update_mode << std::endl; os << " Max sequence length: " << bc.requestsInfo[i].max_sequence_length << std::endl; os << " Request completed: " << bc.request_completed[i] << std::endl; diff --git a/tests/peft_test.sh b/tests/peft_test.sh index 797d4e242b..675511b550 100755 --- a/tests/peft_test.sh +++ b/tests/peft_test.sh @@ -3,7 +3,7 @@ set -e # Cd into directory holding this script -cd "${BASH_SOURCE[0]%/*}" +cd "${BASH_SOURCE[0]%/*}/.." # Token to access private huggingface models (e.g. LLAMA-2) HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-none} @@ -12,52 +12,41 @@ if [[ "$HUGGINGFACE_TOKEN" != "none" ]]; then fi # Create test prompt file -mkdir -p ../inference/prompt -echo '["Two things are infinite: "]' > ../inference/prompt/peft.json -echo '["“Two things are infinite: the universe and human stupidity; and I'\''m not sure about the universe.”"]' > ../inference/prompt/peft_dataset.json +mkdir -p ./inference/prompt +echo '["Two things are infinite: "]' > ./inference/prompt/peft.json +echo '["“Two things are infinite: the universe and human stupidity; and I'\''m not sure about the universe.”"]' > ./inference/prompt/peft_dataset.json # Create output folder -mkdir -p ../inference/output +mkdir -p ./inference/output # Enable backtrace in case we run into a segfault or assertion failure export LEGION_BACKTRACE=1 # Download test model -python ../inference/utils/download_peft_model.py goliaro/llama-160m-lora --base_model_name JackFram/llama-160m -# if first time, add: --refresh-cache +python ./inference/utils/download_peft_model.py goliaro/llama-160m-lora --base_model_name JackFram/llama-160m -# # CPP test -# ../build/inference/peft/peft \ -# -ll:gpu 4 -ll:cpu 4 -ll:util 4 \ -# -tensor-parallelism-degree 4 \ -# -ll:fsize 8192 -ll:zsize 12000 \ -# -llm-model JackFram/llama-160m \ -# -finetuning-dataset ../inference/prompt/peft_dataset.json \ -# -peft-model goliaro/llama-160m-lora \ -# --use-full-precision \ -# --fusion \ -# -enable-peft +# Run PEFT in Huggingface to get ground truth tensors +python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision -# # Python test -# python ../inference/python/ff_peft.py +# Python test +python ./inference/python/ff_peft.py +# Check alignment +python ./tests/peft/peft_alignment_test.py -cd ../build -./inference/peft/peft \ +# C++ test +./build/inference/peft/peft \ -ll:gpu 1 -ll:cpu 4 -ll:util 4 \ -tensor-parallelism-degree 1 \ -ll:fsize 8192 -ll:zsize 12000 \ -llm-model JackFram/llama-160m \ - -finetuning-dataset ../inference/prompt/peft_dataset.json \ + -finetuning-dataset ./inference/prompt/peft_dataset.json \ -peft-model goliaro/llama-160m-lora \ -enable-peft \ --use-full-precision \ --inference-debugging - -cd ../tests/peft -python hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision - -python peft_alignment_test.py +# Check alignment +python ./tests/peft/peft_alignment_test.py # Print succeess message echo ""
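
Usage sketch (illustrative, not part of the patch): the reworked Python PEFT API above is intended to be driven roughly as follows, mirroring inference/python/ff_peft.py; the model names, optimizer settings, and the bare LLM() call are simplified placeholders taken from that example.

    import flexflow.serve as ff

    # The runtime must be initialized first, e.g. ff.init(configs_dict) as in ff_peft.py.
    llm = ff.LLM("JackFram/llama-160m")

    # Inference adapter: its config is loaded from the HuggingFace PEFT repo.
    lora_inference_config = ff.LoraLinearConfig(llm.cache_path, "goliaro/llama-160m-lora")
    llm.add_peft(lora_inference_config)

    # Finetuning adapter: trainable, with an SGD optimizer attached via optimizer_kwargs.
    lora_finetuning_config = ff.LoraLinearConfig(
        llm.cache_path,
        "goliaro/llama-160m-lora",
        trainable=True,
        optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD,
        optimizer_kwargs={
            "learning_rate": 1.0,
            "momentum": 0.0,
            "weight_decay": 0.0,
            "nesterov": False,
        },
    )
    llm.add_peft(lora_finetuning_config)

    # llm.compile(...) lowers each registered config to C++ via LoraLinearConfig.ff_compile();
    # a per-request adapter is then selected with llm.get_ff_peft_id(lora_config).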