Skip to content

Commit

Permalink
Improve API documentation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714103523
  • Loading branch information
iindyk authored and copybara-github committed Jan 13, 2025
1 parent 10fdd0d commit 23a13cc
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 79 deletions.
72 changes: 61 additions & 11 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
import pathlib
import sys
import inspect
import operator
import os

sys.path.insert(0, str(pathlib.Path('..', 'grain').resolve()))

Expand All @@ -25,10 +28,14 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.linkcode',
'myst_nb',
'sphinx_copybutton',
'sphinx_design',
'autoapi.extension',
#'autoapi.extension',
]

templates_path = ['_templates']
Expand All @@ -40,6 +47,9 @@
'tutorials/dataset_basic_tutorial.md',
]

# The main toctree document.
main_doc = 'index'

# Suppress warning in exception basic_data_tutorial
suppress_warnings = [
'misc.highlighting_failure',
Expand Down Expand Up @@ -68,18 +78,58 @@
'navigation_with_keys': True,
}

# -- Extension configuration -------------------------------------------------

# Tell sphinx autodoc to render type hints in the description instead of the signature.
autodoc_typehints = "description"
autodoc_typehints_description_target = "all"

# Customize code links via sphinx.ext.linkcode


def linkcode_resolve(domain, info):
  """Return the GitHub URL of the source for a documented Python object.

  This is the hook that ``sphinx.ext.linkcode`` calls for each documented
  object in order to build its "[source]" link.

  Args:
    domain: Sphinx domain of the object. Only ``'py'`` is handled.
    info: Mapping with the ``'module'`` and ``'fullname'`` of the object.

  Returns:
    The URL of the defining file in the Grain GitHub repository, followed by a
    ``#L<first>-L<last>`` anchor when line information is available, or
    ``None`` when no link can be produced.
  """
  if domain != 'py':
    return None
  if not info['module'] or not info['fullname']:
    return None
  if info['module'].split('.')[0] != 'grain':
    return None

  # Imported lazily, and only once we know the object belongs to Grain, so
  # that unrelated objects never trigger (or fail on) the package import.
  import grain

  try:
    mod = sys.modules.get(info['module'])
    obj = operator.attrgetter(info['fullname'])(mod)
    if isinstance(obj, property):
      obj = obj.fget
    # Follow the ``__wrapped__`` chain to reach the decorated function.
    obj = inspect.unwrap(obj)
    filename = inspect.getsourcefile(obj)
    source, linenum = inspect.getsourcelines(obj)
  except Exception:  # Best effort: any lookup failure means "no link".
    return None
  if filename is None:
    # No Python source file could be determined for the object.
    return None

  filename = os.path.relpath(filename, start=os.path.dirname(grain.__file__))
  # URLs always use forward slashes, whatever the local path separator is.
  filename = filename.replace(os.sep, '/')
  if linenum:
    # ``linenum`` is the first line of the object, so the last one is
    # ``linenum + len(source) - 1``.
    last_line = linenum + max(len(source), 1) - 1
    lines = f'#L{linenum}-L{last_line}'
  else:
    lines = ''
  return f'https://github.com/google/grain/blob/main/grain/{filename}{lines}'
# Autodoc settings
# Should be relative to the source of the documentation
autoapi_dirs = [
'../grain/_src/core',
'../grain/_src/python',
]

autoapi_ignore = [
'*_test.py',
'testdata/*',
'*/dataset/stats.py',
]
# autoapi_dirs = [
# '../grain/_src/core',
# '../grain/_src/python',
# '../grain/python',
# ]

# autoapi_options = [
# 'members',
# 'imported-members',
# ]

# autoapi_ignore = [
# '*_test.py',
# 'testdata/*',
# '*/dataset/stats.py',
# ]

# -- Myst configurations -------------------------------------------------
myst_enable_extensions = ['colon_fence']
Expand Down
40 changes: 14 additions & 26 deletions grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,24 @@ licenses(["notice"])

exports_files(["LICENSE"])

py_library(
name = "core",
srcs = ["core.py"],
srcs_version = "PY3",
# Implicit build flag
deps = [
"//grain/_src/core:config", # build_cleaner: keep
"//grain/_src/core:constants", # build_cleaner: keep
"//grain/_src/core:sharding", # build_cleaner: keep
],
)

py_library(
name = "python",
srcs = ["python.py"],
srcs = [
"__init__.py",
"_src/__init__.py",
"core/__init__.py",
"python/__init__.py",
"python/experimental.py",
"python/fast_proto.py",
],
srcs_version = "PY3",
# Implicit build flag
visibility = ["//visibility:public"],
deps = [
":core", # build_cleaner: keep
":python_experimental", # build_cleaner: keep
"//grain/_src/core:config", # build_cleaner: keep
"//grain/_src/core:constants", # build_cleaner: keep
"//grain/_src/core:monitoring", # build_cleaner: keep
"//grain/_src/core:sharding", # build_cleaner: keep
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python:checkpoint_handlers", # build_cleaner: keep
"//grain/_src/python:data_loader", # build_cleaner: keep
Expand All @@ -38,18 +33,11 @@ py_library(
"//grain/_src/python:load", # build_cleaner: keep
"//grain/_src/python:operations", # build_cleaner: keep
"//grain/_src/python:options", # build_cleaner: keep
"//grain/_src/python:record",
"//grain/_src/python:samplers", # build_cleaner: keep
],
)

py_library(
name = "python_experimental",
srcs = ["python_experimental.py"],
data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
srcs_version = "PY3",
# Implicit build flag
deps = [
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python:shared_memory_array",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"//grain/_src/python/dataset:visualize", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:flatmap", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:interleave", # build_cleaner: keep
Expand Down
Empty file added grain/__init__.py
Empty file.
Empty file added grain/_src/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions grain/core.py → grain/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# pylint: disable=unused-import
# pylint: disable=g-importing-member

from ._src.core.config import config
from ._src.core.constants import (
from grain._src.core.config import config
from grain._src.core.constants import (
DATASET_INDEX,
EPOCH,
INDEX,
Expand All @@ -27,4 +27,4 @@
RECORD_KEY,
SEED,
)
from ._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
47 changes: 29 additions & 18 deletions grain/python.py → grain/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,23 @@
"""APIs for Grain Python backend."""

# pylint: disable=g-importing-member
# pylint: disable=g-import-not-at-top
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import

from . import python_experimental as experimental

from ._src.core.config import config
from ._src.core.constants import DATASET_INDEX, EPOCH, INDEX, META_FEATURES, RECORD, RECORD_KEY, SEED
from ._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from ._src.core.transforms import (
from grain._src.core.config import config
from grain._src.core.constants import (
DATASET_INDEX,
EPOCH,
INDEX,
META_FEATURES,
RECORD,
RECORD_KEY,
SEED,
)
from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
from grain._src.core.transforms import (
BatchTransform as Batch,
FilterTransform,
MapTransform,
Expand All @@ -33,43 +40,47 @@
Transformations,
)

from ._src.python.checkpoint_handlers import PyGrainCheckpointHandler
from ._src.python.data_loader import (
from grain._src.python.checkpoint_handlers import PyGrainCheckpointHandler
from grain._src.python.data_loader import (
DataLoader,
PyGrainDatasetIterator,
)
from ._src.python.data_sources import (
from grain._src.python.data_sources import (
ArrayRecordDataSource,
InMemoryDataSource,
RandomAccessDataSource,
RangeDataSource,
)
from ._src.python.dataset.base import DatasetSelectionMap
from ._src.python.dataset.dataset import (
from grain._src.python.dataset.base import DatasetSelectionMap
from grain._src.python.dataset.dataset import (
MapDataset,
IterDataset,
DatasetIterator,
)

from ._src.python.load import load
from ._src.python.operations import (
from grain._src.python.load import load
from grain._src.python.operations import (
BatchOperation,
FilterOperation,
MapOperation,
Operation,
RandomMapOperation,
)
from ._src.python.options import ReadOptions, MultiprocessingOptions
from ._src.python.record import (Record, RecordMetadata)
from ._src.python.samplers import (
from grain._src.python.options import ReadOptions, MultiprocessingOptions
from grain._src.python.record import (Record, RecordMetadata)
from grain._src.python.samplers import (
IndexSampler,
Sampler,
SequentialSampler,
)
from ._src.python.shared_memory_array import SharedMemoryArray
from grain._src.python.shared_memory_array import SharedMemoryArray
from grain.python import experimental, fast_proto

# These are imported only if Orbax is present.
try:
from ._src.python.checkpoint_handlers import PyGrainCheckpointSave, PyGrainCheckpointRestore # pylint: disable=g-import-not-at-top
from grain._src.python.checkpoint_handlers import (
PyGrainCheckpointSave,
PyGrainCheckpointRestore,
)
except ImportError:
pass
38 changes: 22 additions & 16 deletions grain/python_experimental.py → grain/python/experimental.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,36 +17,42 @@
# pylint: disable=g-bad-import-order
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import
# pylint: disable=g-import-not-at-top

from ._src.python.dataset.base import (
from grain._src.python.dataset.base import (
DatasetOptions,
ExecutionTrackingMode,
)
from ._src.python.dataset.dataset import (
from grain._src.python.dataset.dataset import (
apply_transformations,
WithOptionsIterDataset,
)
from ._src.python.dataset.transformations.flatmap import (
from grain._src.python.dataset.transformations.flatmap import (
FlatMapMapDataset,
FlatMapIterDataset,
)
from ._src.python.dataset.transformations.interleave import (
from grain._src.python.dataset.transformations.interleave import (
InterleaveIterDataset,
)
from ._src.python.dataset.transformations.map import RngPool
from ._src.python.dataset.transformations.mix import ConcatenateMapDataset
from ._src.python.dataset.transformations.packing import FirstFitPackIterDataset
from ._src.python.dataset.transformations.prefetch import (
from grain._src.python.dataset.transformations.map import RngPool
from grain._src.python.dataset.transformations.mix import ConcatenateMapDataset
from grain._src.python.dataset.transformations.packing import (
FirstFitPackIterDataset,
)
from grain._src.python.dataset.transformations.prefetch import (
MultiprocessPrefetchIterDataset,
ThreadPrefetchIterDataset,
)
from ._src.python.dataset.transformations.shuffle import WindowShuffleMapDataset
from ._src.python.dataset.transformations.zip import ZipMapDataset
from ._src.core.transforms import (
from grain._src.python.dataset.transformations.shuffle import (
WindowShuffleMapDataset,
)
from grain._src.python.dataset.transformations.zip import ZipMapDataset
from grain._src.core.transforms import (
FlatMapTransform,
MapWithIndexTransform,
)
from ._src.python.experimental.example_packing.packing import PackAndBatchOperation
from ._src.python.experimental.index_shuffle.python.index_shuffle_module import index_shuffle
from grain._src.python.experimental.example_packing.packing import (
PackAndBatchOperation,
)
from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import (
index_shuffle,
)
10 changes: 5 additions & 5 deletions grain/python_proto.py → grain/python/fast_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
"""Proto PyGrain APIs."""

# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
# pylint: disable=wildcard-import

from ._src.python.proto.decode import parse_tf_example
from ._src.python.proto.decode import parse_tf_example_experimental
from ._src.python.proto.encode import make_tf_example
from grain._src.python.proto.decode import (
parse_tf_example,
parse_tf_example_experimental,
)
from grain._src.python.proto.encode import make_tf_example

0 comments on commit 23a13cc

Please sign in to comment.