
Commit

minor fix
yangky11 committed Jul 13, 2024
1 parent 1c2b541 commit 4263cfb
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 3 additions & 1 deletion prover/tactic_generator.py
@@ -270,7 +270,9 @@ def __init__(
 
     def initialize(self) -> None:
         self.hf_gen.initialize()
-        self.retriever = PremiseRetriever.load_hf(self.ret_path, self.device)
+        self.retriever = PremiseRetriever.load_hf(
+            self.ret_path, self.max_inp_seq_len, self.device
+        )
         self.retriever.load_corpus(self.indexed_corpus_path)
 
     async def generate(

2 changes: 1 addition & 1 deletion retrieval/index.py
@@ -30,7 +30,7 @@ def main() -> None:
         device = torch.device("cpu")
     else:
         device = torch.device("cuda")
-    model = PremiseRetriever.load_hf(args.ckpt_path, device, max_seq_len=2048)
+    model = PremiseRetriever.load_hf(args.ckpt_path, 2048, device)
     model.load_corpus(args.corpus_path)
     model.reindex_corpus(batch_size=args.batch_size)

4 changes: 1 addition & 3 deletions retrieval/model.py
@@ -51,10 +51,8 @@ def load(cls, ckpt_path: str, device, freeze: bool) -> "PremiseRetriever":
 
     @classmethod
     def load_hf(
-        cls, ckpt_path: str, device: int, dtype=None, max_seq_len: Optional[int] = None
+        cls, ckpt_path: str, max_seq_len: int, device: int, dtype=None
     ) -> "PremiseRetriever":
-        if max_seq_len is None:
-            max_seq_len = 999999999999
         model = PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100).to(device).eval()
         if dtype is not None:
             return model.to(dtype)
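
Taken together, the commit reorders the parameters of PremiseRetriever.load_hf so that max_seq_len is a required positional argument placed before device, and it drops the old fallback of 999999999999 that was used when no length was supplied. Below is a minimal sketch of the new call convention, assuming the retrieval.model import path from this repository; the checkpoint and corpus paths are illustrative placeholders, while the sequence length of 2048 mirrors the value used in retrieval/index.py.

    # Minimal sketch of the post-commit call convention. The paths below are
    # placeholders; only the argument order (ckpt_path, max_seq_len, device,
    # dtype=None) comes from the diff above.
    import torch

    from retrieval.model import PremiseRetriever

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # max_seq_len (2048, as in retrieval/index.py) is now passed positionally,
    # right after the checkpoint path and before the device.
    model = PremiseRetriever.load_hf("path/to/retriever_checkpoint", 2048, device)
    model.load_corpus("path/to/indexed_corpus")  # same loading step as in the callers above

Making max_seq_len required pushes the truncation decision onto every caller, which is why both call sites touched by this commit now pass a length explicitly.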
