Commit e8ad3f8

start grads

goliaro committed Jul 2, 2024
1 parent c9e70ae commit e8ad3f8

Showing 2 changed files with 120 additions and 7 deletions.
13 changes: 13 additions & 0 deletions src/ops/lora_linear.cc
@@ -656,9 +656,12 @@ void LoraLinear::inference_task(Task const *task,
assert(false);
}

int rank, num_tokens;
for (auto it = m->model_state.begin(); it != m->model_state.end(); ++it) {
PEFTModelID peft_model_id = it->first;
LoraLinearWeight weight = m->model_state[peft_model_id].weights;
rank = weight.rank;
num_tokens = input.domain.get_volume() / weight.in_dim;
fs::path dst_filepath_weights =
get_dst_folder("weights", m->decoding_step, shard_id) / layername;
std::string filenameA =
@@ -694,6 +697,16 @@ void LoraLinear::inference_task(Task const *task,
} else {
assert(false);
}
// low-rank activation (intermediate LoRA result: input projected by weight_A)
filename = dst_filepath.string() + ".low_rank_activation";
// the alignment test assumes a fixed 128-token buffer (see peft_alignment_test.py)
assert(num_tokens == 128);
if (output.data_type == DT_FLOAT) {
save_tensor((float *)m->low_rank_activation, rank * num_tokens, filename.c_str());
} else if (output.data_type == DT_HALF) {
save_tensor((half *)m->low_rank_activation, rank * num_tokens, filename.c_str());
} else {
assert(false);
}
m->decoding_step++;
}
}
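For context, the low_rank_activation buffer dumped above holds the intermediate result of the LoRA computation, one rank-sized vector per token, which is why rank * num_tokens elements are written. A minimal NumPy sketch of the shapes involved (illustrative only, not the FlexFlow kernel; all sizes and variable names below are made up):

import numpy as np

# Illustrative sizes only: LoRA rank r, input/output dims, token count.
num_tokens, d_in, d_out, r, scaling = 8, 32, 64, 16, 1.0

x = np.random.randn(num_tokens, d_in)   # input activations to the LoRA-wrapped linear
A = np.random.randn(d_in, r)            # lora.weight_A
B = np.random.randn(r, d_out)           # lora.weight_B

low_rank_activation = x @ A             # (num_tokens, r): the tensor saved above
lora_out = scaling * (low_rank_activation @ B)  # added to the base linear output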
114 changes: 107 additions & 7 deletions tests/peft/peft_alignment_test.py
@@ -161,7 +161,9 @@ def get_hf_tensor(hf_tensor_name, tensor_comparison_idx):
return hf_tensor

def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE):
ff_tensor_filename = f"{ff_tensor_name}.{tensor_comparison_idx.ff_tensor_type}_{tensor_comparison_idx.ff_tensor_idx}"
ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else ""
ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else ""
ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}"
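# Resulting filenames (examples based on the dump code in lora_linear.cc; "<layer>" stands for the FlexFlow layer name):
#   ff_tensor_type="output", ff_tensor_idx=0                  -> "<layer>.output_0"
#   ff_tensor_type="low_rank_activation", ff_tensor_idx=None  -> "<layer>.low_rank_activation"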
ff_tensor_path = os.path.join(ff_fwd_folder, ff_tensor_filename)
if not os.path.isfile(ff_tensor_path):
raise FileNotFoundError(f"File '{ff_tensor_path}' not found")
@@ -295,13 +297,16 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
compare(hf_tensor, ff_tensor, label=f"LoRA_A {i} input", tolerance=1e-4)
torch.testing.assert_close(hf_down_proj_in, hf_tensor, rtol=1.3e-6, atol=1e-5)

# LoRA intermediate (HF only)
# LoRA intermediate
input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0)
output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="low_rank_activation", hf_tensor_idx=0, ff_tensor_idx=None)
hf_lora_A_out = get_hf_tensor(hf_tensor_name, output_comparison)
hf_tensor_name = f"layers.{i}.mlp.down_proj.lora_B.default"
hf_lora_B_in = get_hf_tensor(hf_tensor_name, input_comparison)
torch.testing.assert_close(hf_lora_A_out, hf_lora_B_in, rtol=1.3e-6, atol=1e-5)
ff_tensor_name = f"layers.{i}.layers.{i}.mlp.down_proj.lora"
ff_lora_A_out = get_ff_tensor(ff_tensor_name, output_comparison, hf_lora_A_out.shape, tp_type=TPType.TO_REDUCE)
compare(hf_lora_A_out, ff_lora_A_out, label=f"LoRA_A {i} output", tolerance=1e-4)

# LoRA_B
hf_tensor_name = f"layers.{i}.mlp.down_proj.lora_B.default"
@@ -594,8 +599,103 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
# if i > 1:
# compare(hf_tensor, input_layernorm1, label=f"Input layernorm {i} gradient input", tolerance=1e-5)

def check_step(self, step_idx):
raise NotImplementedError()
def check_step(self, step_idx=0):
hf_weight_folder = os.path.join(hf_path, "weights", f"step_{step_idx}")
ff_weight_folder = os.path.join(ff_path, "weights", f"step_{step_idx}", "shard_0")
def convert_hf_filename_to_ff(hf_filename):
assert hf_filename.startswith("layers.")
layernum = hf_filename.split("layers.")[1].split(".")[0]
f_version = f"layers.{layernum}."
f_version += hf_filename.replace(".base_layer", "").replace(".default", "")
# LoRA in HuggingFace is split into A and B operators; in FF we use a single operator (see the worked example after this helper).
f_version = f_version.replace("lora_A", "lora.weight_A").replace("lora_B", "lora.weight_B")
return f_version
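# Worked example of the mapping above (illustrative input): the HF name
# "layers.0.mlp.down_proj.lora_B.default.gradient" becomes
# "layers.0.layers.0.mlp.down_proj.lora.weight_B.gradient": the layer prefix
# is duplicated, ".default" is dropped, and "lora_B" maps to "lora.weight_B".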
def get_hf_tensor(hf_tensor_name):
hf_tensor_path = os.path.join(hf_weight_folder, hf_tensor_name)

if not os.path.isfile(hf_tensor_path):
raise FileNotFoundError(f"File '{hf_tensor_path}' not found")
hf_tensor = torch.load(hf_tensor_path, map_location='cpu')
return hf_tensor
def get_ff_tensor(ff_tensor_name, hf_shape, tp_type=TPType.REPLICATE, pre=False):
ff_tensor_path = os.path.join(ff_weight_folder, ff_tensor_name)
if pre:
ff_tensor_path = ff_tensor_path.replace(f"step_{step_idx}", f"step_{step_idx}_pre")
if not os.path.isfile(ff_tensor_path):
raise FileNotFoundError(f"File '{ff_tensor_path}' not found")

ff_shape = list(hf_shape)[::-1]
if tp_type == TPType.PARTITION:
ff_shape[0] //= self.tp_degree

ff_tensors = [load_ff_tensor(ff_tensor_path.replace("shard_0", f"shard_{tp_idx}"), ff_shape) for tp_idx in range(self.tp_degree)]
if self.tp_degree > 1:
# if replicate, check that they are identical
if tp_type == TPType.REPLICATE:
assert(are_np_arrays_identical(ff_tensors))
ff_tensor = ff_tensors[0]
# if partition, concatenate along the partition dimension
elif tp_type == TPType.PARTITION:
ff_tensor = np.concatenate(ff_tensors, axis=0)
# if to_reduce, sum along the partition dimension
elif tp_type == TPType.TO_REDUCE:
ff_tensor = np.sum(ff_tensors, axis=0)
else:
ff_tensor = ff_tensors[0]
ff_tensor = torch.from_numpy(ff_tensor)
return ff_tensor
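# Illustrative behavior with tp_degree=2: a REPLICATE tensor must be identical
# on both shards, a PARTITION tensor is split along its first FlexFlow
# dimension and concatenated back here, and a TO_REDUCE tensor (e.g. the
# LoRA_B gradient checked below) is stored as per-shard partial sums that are
# added element-wise.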
def compare(hf_tensor, ff_tensor, label="", tolerance=1e-5):
ff_tensor = ff_tensor.to(hf_tensor.dtype)
hf_tensor = hf_tensor.T
try:
# torch.testing.assert_close(hf_tensor, ff_tensor, rtol=rtol, atol=tolerance)
if not np.allclose(hf_tensor.numpy(), ff_tensor.numpy(), atol=tolerance):
mismatches = np.where(~np.isclose(hf_tensor, ff_tensor, atol=tolerance))[0]
print(f"Pct mismatch {label}: {100.0*(np.prod(mismatches.shape) / ff_tensor.numel()):.3f}%")
assert(np.prod(mismatches.shape) <= .05 * ff_tensor.numel())
except Exception as e:
print(f"Error in comparison {label}:\n{e}\n")
print("HF tensor:")
print(hf_tensor.squeeze())
print("FF tensor:")
print(ff_tensor.squeeze())
raise e
print(f"-- optimizer pass {step_idx} --")

for i in range(self.num_layers-1, -1, -1):
# LoRA_B gradient
hf_gradient_name = f"layers.{i}.mlp.down_proj.lora_B.default.gradient"
hf_gradient = get_hf_tensor(hf_gradient_name)
hf_original_weight_name = f"layers.{i}.mlp.down_proj.lora_B.default.weight_original"
hf_original_weight = get_hf_tensor(hf_original_weight_name)
hf_finetuned_weight_name = f"layers.{i}.mlp.down_proj.lora_B.default.weight_finetuned"
hf_finetuned_weight = get_hf_tensor(hf_finetuned_weight_name)
torch.testing.assert_close(hf_gradient, hf_original_weight-hf_finetuned_weight, rtol=1.3e-6, atol=1e-5)
print("ok")
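# Note: with plain SGD, weight_original - weight_finetuned = lr * gradient,
# so the equality above assumes the finetuning run used lr = 1.0 (and no
# momentum or weight decay).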
ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)
ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.TO_REDUCE)

ff_out_gradient_name = f"layers.{i}.layers.{i}.mlp.down_proj.lora.output_gradient_0"
ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0")
ff_bwd_folder = os.path.join(ff_path, "bwd", f"step_{step_idx}", "shard_0")
ff_out_gradient = load_ff_tensor(os.path.join(ff_bwd_folder, ff_out_gradient_name), [self.hidden_size, 128])[:,:self.num_tokens]
ff_out_gradient = torch.from_numpy(ff_out_gradient)
print("Output gradient shape: ", ff_out_gradient.shape)
ff_low_rank_activation = f"layers.{i}.layers.{i}.mlp.down_proj.lora.low_rank_activation"
ff_low_rank_activation = load_ff_tensor(os.path.join(ff_fwd_folder, ff_low_rank_activation), [16, 128])[:,:self.num_tokens] # [lora rank, max-token buffer]
ff_low_rank_activation = torch.from_numpy(ff_low_rank_activation)
print("Low rank activation shape: ", ff_low_rank_activation.shape)
simulated_weight_grad = ff_low_rank_activation @ ff_out_gradient.T
print("Simulated weight grad shape: ", simulated_weight_grad.shape)
# print(simulated_weight_grad)
# print(ff_gradient)

compare(hf_gradient, simulated_weight_grad, label=f"LoRA_B {i} simulated gradient", tolerance=1e-5)
compare(hf_gradient, ff_gradient, label=f"LoRA_B {i} gradient", tolerance=1e-5)
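# Rationale for simulated_weight_grad: with low_rank_activation = A @ x of
# shape (rank, num_tokens) and the output gradient dL/dY of shape
# (hidden_size, num_tokens), the chain rule gives dL/dB = dL/dY @
# low_rank_activation^T; the code computes its transpose,
# low_rank_activation @ dL/dY^T, which matches the transposed HF layout used
# inside compare(). Any LoRA scaling factor is assumed to be 1 here.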

# LoRA_A gradient
hf_gradient_name = f"layers.{i}.mlp.down_proj.lora_A.default.gradient"
ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)

def check_weights_alignment(num_layers=12):
print("-- Weights alignment --")
@@ -922,8 +1022,8 @@ def check_llama_bwd_pass(hf_config, tot_num_layers = 12, step_idx=0):
llama_alignment = LllamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree)
# llama_alignment.check_weights_alignment()
llama_alignment.check_fwd_pass()
llama_alignment.check_bwd_pass()
# hf_config = get_model_config(args.model_name)
# llama_alignment.check_bwd_pass()
llama_alignment.check_step()

# check_weights_alignment(num_layers=args.num_layers)
# n_steps=5
