Skip to content

Commit

Permalink
Better pyo3 bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 6, 2024
1 parent 3118696 commit 5eb9285
Show file tree
Hide file tree
Showing 4 changed files with 366 additions and 1 deletion.
10 changes: 10 additions & 0 deletions yomikomi-pyo3/py_src/yomikomi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Generated content DO NOT EDIT
from .. import yomikomi

audio = yomikomi.audio
jsonl = yomikomi.jsonl
stream = yomikomi.stream
warc = yomikomi.warc
JsonFilter = yomikomi.JsonFilter
StreamIter = yomikomi.StreamIter
YkIterable = yomikomi.YkIterable
111 changes: 111 additions & 0 deletions yomikomi-pyo3/py_src/yomikomi/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike

@staticmethod
def audio(file):
"""
Returns a stream that iterates over the pcm data in an audio file.
"""
pass

@staticmethod
def jsonl(file, *, offset=0, field=None, filters=..., include_if_missing=False):
"""
Returns a stream that iterates over the text contained in a specific field of a jsonl file.
"""
pass

@staticmethod
def stream(iterable, *, field=None):
"""
Returns a stream based on a python iterator. The iterator can either return a whole dictionary
or if `field` is specified single values which will be embedded in a dictionary with a single
entry named as per the field argument.
"""
pass

@staticmethod
def warc(file):
"""
Returns a stream that iterates over the documents contained in a warc file.
"""
pass

class JsonFilter:
@staticmethod
def eq(field, value, *, include_if_missing=False):
""" """
pass

@staticmethod
def greater(field, value, *, include_if_missing=False):
""" """
pass

@staticmethod
def greater_eq(field, value, *, include_if_missing=False):
""" """
pass

@staticmethod
def lower(field, value, *, include_if_missing=False):
""" """
pass

@staticmethod
def lower_eq(field, value, *, include_if_missing=False):
""" """
pass

@staticmethod
def neq(field, value, *, include_if_missing=False):
""" """
pass

class StreamIter:
pass

class YkIterable:
def batch(self, batch_size, *, return_incomplete_last_batch=False):
""" """
pass

def enumerate(self, field):
""" """
pass

def filter(self, f, *, field=None):
"""
Filters a stream, the elements are kept if the provided function `f` returns `True` on
them, otherwise they are discarded. If `field` is specified, the function `f` is only
passed the value associated to this field rather than a whole dictionary.
"""
pass

def filter_key(self, key, *, remove=False):
""" """
pass

def key_transform(self, f, *, field):
""" """
pass

def map(self, f):
""" """
pass

def prefetch(self, *, num_threads, buffer_size=None):
""" """
pass

def sliding_window(self, window_size, *, stride=None, field=..., overlap_over_samples=False):
""" """
pass

def tokenize(self, path, *, in_field=..., out_field=None, report_bpb=True, include_bos=True, include_eos=False):
"""
Loads a sentencepiece tokenizer, and use it to tokenize the field passed as an argument of
this function.
"""
pass
8 changes: 7 additions & 1 deletion yomikomi-pyo3/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[build-system]
requires = ["maturin>=1.4,<2.0"]
requires = [
"maturin>=1.4,<2.0",
"cmake>=3.12",
]
build-backend = "maturin"

[project]
Expand All @@ -13,4 +16,7 @@ classifiers = [
dynamic = ["version"]

[tool.maturin]
python-source = "py_src"
module-name = "yomikomi.yomikomi"
bindings = 'pyo3'
features = ["pyo3/extension-module"]
238 changes: 238 additions & 0 deletions yomikomi-pyo3/stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
import argparse
import inspect
import os
from typing import Optional
import black
from pathlib import Path
import re


INDENT = " " * 4
GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
RETURN_TYPE_MARKER = "&RETURNS&: "
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")


def do_indent(text: Optional[str], indent: str):
if text is None:
return ""
return text.replace("\n", f"\n{indent}")


def function(obj, indent: str, text_signature: str = None):
if text_signature is None:
text_signature = obj.__text_signature__

text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
doc_string = obj.__doc__
if doc_string is None:
doc_string = ""

# Check if we have a return type annotation in the docstring
return_type = None
doc_lines = doc_string.split("\n")
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
# Extract the return type and remove it from the docstring
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip()
doc_string = "\n".join(doc_lines[:-1])

string = ""
if return_type:
string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n"
else:
string += f"{indent}def {obj.__name__}{text_signature}:\n"
indent += INDENT
string += f'{indent}"""\n'
string += f"{indent}{do_indent(doc_string, indent)}\n"
string += f'{indent}"""\n'
string += f"{indent}pass\n"
string += "\n"
string += "\n"
return string


def member_sort(member):
if inspect.isclass(member):
value = 10 + len(inspect.getmro(member))
else:
value = 1
return value


def fn_predicate(obj):
value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
if value:
return obj.__text_signature__ and not obj.__name__.startswith("_")
if inspect.isgetsetdescriptor(obj):
return not obj.__name__.startswith("_")
return False


def get_module_members(module):
members = [
member
for name, member in inspect.getmembers(module)
if not name.startswith("_") and not inspect.ismodule(member)
]
members.sort(key=member_sort)
return members


def pyi_file(obj, indent=""):
string = ""
if inspect.ismodule(obj):
string += GENERATED_COMMENT
string += TYPING
members = get_module_members(obj)
for member in members:
string += pyi_file(member, indent)

elif inspect.isclass(obj):
indent += INDENT
mro = inspect.getmro(obj)
if len(mro) > 2:
inherit = f"({mro[1].__name__})"
else:
inherit = ""
string += f"class {obj.__name__}{inherit}:\n"

body = ""
if obj.__doc__:
body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'

fns = inspect.getmembers(obj, fn_predicate)

# Init
if obj.__text_signature__:
body += f"{indent}def __init__{obj.__text_signature__}:\n"
body += f"{indent+INDENT}pass\n"
body += "\n"


for name, fn in fns:
body += pyi_file(fn, indent=indent)

if not body:
body += f"{indent}pass\n"

string += body
string += "\n\n"

elif inspect.isbuiltin(obj):
string += f"{indent}@staticmethod\n"
string += function(obj, indent)

elif inspect.ismethoddescriptor(obj):
string += function(obj, indent)

elif inspect.isgetsetdescriptor(obj):
# TODO it would be interesting to add the setter maybe ?
string += f"{indent}@property\n"
string += function(obj, indent, text_signature="(self)")

elif obj.__class__.__name__ == "DType":
string += f"class {str(obj).lower()}(DType):\n"
string += f"{indent+INDENT}pass\n"
else:
raise Exception(f"Object {obj} is not supported")
return string


def py_file(module, origin):
members = get_module_members(module)

string = GENERATED_COMMENT
string += f"from .. import {origin}\n"
string += "\n"
for member in members:
if hasattr(member, "__name__"):
name = member.__name__
else:
name = str(member)
string += f"{name} = {origin}.{name}\n"
return string


def do_black(content, is_pyi):
mode = black.Mode(
target_versions={black.TargetVersion.PY35},
line_length=119,
is_pyi=is_pyi,
string_normalization=True,
)
try:
return black.format_file_contents(content, fast=True, mode=mode)
except black.NothingChanged:
return content


def write(module, directory, origin, check=False):
submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]

filename = os.path.join(directory, "__init__.pyi")
pyi_content = pyi_file(module)
pyi_content = do_black(pyi_content, is_pyi=True)
os.makedirs(directory, exist_ok=True)
if check:
with open(filename, "r") as f:
data = f.read()
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(pyi_content)

filename = os.path.join(directory, "__init__.py")
py_content = py_file(module, origin)
py_content = do_black(py_content, is_pyi=False)
os.makedirs(directory, exist_ok=True)

is_auto = False
if not os.path.exists(filename):
is_auto = True
else:
with open(filename, "r") as f:
line = f.readline()
if line == GENERATED_COMMENT:
is_auto = True

if is_auto:
if check:
with open(filename, "r") as f:
data = f.read()
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(py_content)

for name, submodule in submodules:
write(submodule, os.path.join(directory, name), f"{name}", check=check)


def extract_additional_types(module):
additional_types = {}
for name, member in inspect.getmembers(module):
if inspect.isclass(member):
if hasattr(member, "__name__"):
name = member.__name__
else:
name = str(member)
if name not in additional_types:
additional_types[name] = member
return additional_types


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")

args = parser.parse_args()

cwd = Path.cwd()
directory = "py_src/yomikomi/"

import yomikomi
write(yomikomi.yomikomi, directory, "yomikomi", check=args.check)

0 comments on commit 5eb9285

Please sign in to comment.