Skip to content

Commit

Permalink
support plain data as nmrnet input
Browse files Browse the repository at this point in the history
  • Loading branch information
link89 committed Jan 22, 2025
1 parent cfff4c3 commit ac39507
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions ai2_kit/algorithm/uninmr/infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import Namespace
from io import StringIO
from itertools import product
from scipy.spatial import distance_matrix

Expand Down Expand Up @@ -285,9 +286,11 @@ def predict(model: UniMatModel, dataloader: DataLoader,

def predict_cli(model_path: str, dict_path: str, saved_dir: str,
selected_atom: str, nmr_type: str, use_cuda=False, cuda_device_id=None,
smiles: str='', data_file: str = '', data_content: str = '', ase_format = None):
smiles: str='', data_file: str = '', data: str = '', format = None):
"""
Command line interface for NMRNet prediction
Command line interface for NMRNet prediction.
You can provide input data with one of `data_file`, `data` or `smiles`.
:param model_path: path to the model checkpoint, e.g 'model.pt'
:param dict_path: path to the dictionary file, e.g 'dict.txt'
Expand All @@ -297,10 +300,14 @@ def predict_cli(model_path: str, dict_path: str, saved_dir: str,
:param use_cuda: whether to use GPU for prediction, default is False
:param cuda_device_id: GPU device id, default is None, required when use_cuda is True
:param data_file: path to the input data file, which should be able to parse by ASE
:param data: input data string, default is '', you can provide data directly
:param smiles: SMILES string for prediction, default is ''
:param format: format of the input data file, default is None, you can find the supported format in ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html
"""
if data_file:
atoms = ase.io.read(data_file, index=0, format=ase_format) # type: ignore
atoms = ase.io.read(data_file, index=0, format=format) # type: ignore
elif data:
atoms = ase.io.read(StringIO(data), index=0, format=format) # type: ignore
elif smiles:
atoms = smiles_to_atoms(smiles)
else:
Expand Down

0 comments on commit ac39507

Please sign in to comment.