Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Unsort #124

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions src/textual_fastdatatable/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from contextlib import suppress
from pathlib import Path
Expand All @@ -15,6 +16,7 @@
Union,
)

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.lib as pal
Expand Down Expand Up @@ -124,8 +126,12 @@ def row_count(self) -> int:
pass

@property
@abstractmethod
def column_count(self) -> int:
return len(self.columns)
"""
The number of columns
"""
pass

@property
@abstractmethod
Expand Down Expand Up @@ -165,7 +171,7 @@ def append_column(self, label: str, default: Any | None = None) -> int:
@abstractmethod
def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
"""
Returns new row indicies
Returns new row indices
"""
pass

Expand All @@ -176,7 +182,7 @@ def drop_row(self, row_index: int) -> None:
@abstractmethod
def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
"""
Raises IndexError if bad indicies
Raises IndexError if bad indices
"""

@abstractmethod
Expand Down Expand Up @@ -213,6 +219,13 @@ def __init__(self, data: pa.Table, max_rows: int | None = None) -> None:
self.data = data.slice(offset=0, length=max_rows)
else:
self.data = data

self._index_col = f"_index_{str(uuid.uuid4())[-12:]}"
self.data = (
self.data.append_column(self._index_col, pa.array(np.arange(self.row_count)))
.select([self._index_col] + self.data.column_names)
)

self._console = Console()
self._column_content_widths: list[int] = []

Expand Down Expand Up @@ -280,11 +293,11 @@ def row_count(self) -> int:

@property
def column_count(self) -> int:
return self.data.num_columns
return self.data.num_columns - 1

@property
def columns(self) -> Sequence[str]:
return self.data.column_names
return self.data.column_names[1:]

@property
def column_content_widths(self) -> list[int]:
Expand All @@ -297,14 +310,14 @@ def column_content_widths(self) -> list[int]:
return self._column_content_widths

def get_row_at(self, index: int) -> Sequence[Any]:
row: Dict[str, Any] = self.data.slice(index, length=1).to_pylist()[0]
row: Dict[str, Any] = self.data.select(self.columns).slice(index, length=1).to_pylist()[0]
return list(row.values())

def get_column_at(self, column_index: int) -> Sequence[Any]:
return self.data[column_index].to_pylist()
return self.data[column_index+1].to_pylist()

def get_cell_at(self, row_index: int, column_index: int) -> Any:
return self.data[column_index][row_index].as_py()
return self.data[column_index+1][row_index].as_py()

def append_column(self, label: str, default: Any | None = None) -> int:
"""
Expand All @@ -319,11 +332,11 @@ def append_column(self, label: str, default: Any | None = None) -> int:
self.data = self.data.append_column(label, arr)
if self._column_content_widths:
self._column_content_widths.append(measure_width(default, self._console))
return self.data.num_columns - 1
return self.column_count - 1

def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
rows = list(records)
indicies = list(range(self.row_count, self.row_count + len(rows)))
indices = list(range(self.row_count, self.row_count + len(rows)))
records_with_headers = [self.data.column_names, *rows]
pydict = self._pydict_from_records(records_with_headers, has_header=True)
old_rows = self.data.to_batches()
Expand All @@ -333,7 +346,7 @@ def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
)
self.data = pa.Table.from_batches([*old_rows, new_rows])
self._reset_content_widths()
return indicies
return indices

def drop_row(self, row_index: int) -> None:
if row_index < 0 or row_index >= self.row_count:
Expand All @@ -342,7 +355,6 @@ def drop_row(self, row_index: int) -> None:
below = self.data.slice(row_index + 1).to_batches()
self.data = pa.Table.from_batches([*above, *below])
self._reset_content_widths()
pass

def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
column = self.data.column(column_index)
Expand All @@ -361,13 +373,14 @@ def update_cell(self, row_index: int, column_index: int, value: Any) -> None:
)

def sort(
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str | None
) -> None:
"""
by: str sorts table by the data in the column with that name (asc).
by: list[tuple] sorts the table by the named column(s) with the directions
indicated.
"""
by = by if by else self._index_col
self.data = self.data.sort_by(by)

def _reset_content_widths(self) -> None:
Expand Down Expand Up @@ -475,6 +488,13 @@ def __init__(self, data: pl.DataFrame, max_rows: int | None = None) -> None:
self.data = data.slice(offset=0, length=max_rows)
else:
self.data = data

self._index_col = f"_index_{str(uuid.uuid4())[-12:]}"
self.data = (
self.data.with_columns(pl.Series(self._index_col, np.arange(self.row_count)))
.select([self._index_col] + self.data.columns)
)

self._console = Console()
self._column_content_widths: list[int] = []

Expand All @@ -488,41 +508,41 @@ def source_row_count(self) -> int:

@property
def row_count(self) -> int:
return len(self.data)
return self.data.height

@property
def column_count(self) -> int:
return len(self.data.columns)
return self.data.width - 1

@property
def columns(self) -> Sequence[str]:
return self.data.columns
return self.data.columns[1:]

def get_row_at(self, index: int) -> Sequence[Any]:
if index < 0 or index >= len(self.data):
if index < 0 or index >= self.row_count:
raise IndexError(
f"Cannot get row={index} in table with {len(self.data)} rows and {len(self.data.columns)} cols"
f"Cannot get row={index} in table with {self.row_count} rows and {self.column_count} cols"
)
return list(self.data.slice(index, length=1).to_dicts()[0].values())
return list(self.data.select(self.columns).slice(index, length=1).to_dicts()[0].values())

def get_column_at(self, column_index: int) -> Sequence[Any]:
if column_index < 0 or column_index >= len(self.data.columns):
if column_index < 0 or column_index >= self.column_count:
raise IndexError(
f"Cannot get column={column_index} in table with {len(self.data)} rows and {len(self.data.columns)} cols"
f"Cannot get column={column_index} in table with {self.row_count} rows and {self.column_count} cols"
)
return list(self.data.to_series(column_index))
return list(self.data.to_series(column_index+1))

def get_cell_at(self, row_index: int, column_index: int) -> Any:
if (
row_index >= len(self.data)
row_index >= self.row_count
or row_index < 0
or column_index < 0
or column_index >= len(self.data.columns)
or column_index >= self.column_count
):
raise IndexError(
f"Cannot get cell at row={row_index} col={column_index} in table with {len(self.data)} rows and {len(self.data.columns)} cols"
f"Cannot get cell at row={row_index} col={column_index} in table with {self.row_count} rows and {self.column_count} cols"
)
return self.data.to_series(column_index)[row_index]
return self.data.to_series(column_index+1)[row_index]

def drop_row(self, row_index: int) -> None:
if row_index < 0 or row_index >= self.row_count:
Expand All @@ -536,10 +556,10 @@ def append_rows(self, records: Iterable[Iterable[Any]]) -> list[int]:
rows_to_add = pl.from_dicts(
[dict(zip(self.data.columns, row)) for row in records]
)
indicies = list(range(self.row_count, self.row_count + len(rows_to_add)))
indices = list(range(self.row_count, self.row_count + len(rows_to_add)))
self.data = pl.concat([self.data, rows_to_add])
self._reset_content_widths()
return indicies
return indices

def append_column(self, label: str, default: Any | None = None) -> int:
"""
Expand All @@ -554,7 +574,7 @@ def append_column(self, label: str, default: Any | None = None) -> int:
self._column_content_widths.append(
measure_width(default, self._console)
)
return len(self.data.columns) - 1
return self.column_count - 1

def _reset_content_widths(self) -> None:
self._column_content_widths = []
Expand Down Expand Up @@ -621,13 +641,15 @@ def _measure(self, arr: pl.Series) -> int:
return width

def sort(
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str
self, by: list[tuple[str, Literal["ascending", "descending"]]] | str | None
) -> None:
"""
by: str sorts table by the data in the column with that name (asc).
by: list[tuple] sorts the table by the named column(s) with the directions
indicated.
"""
by = by if by else self._index_col

if isinstance(by, str):
cols = [by]
typs = [False]
Expand Down
8 changes: 4 additions & 4 deletions src/textual_fastdatatable/data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,9 +1473,9 @@ def add_rows(self, rows: Iterable[Iterable[Any]]) -> list[int]:
"""
if self.backend is None:
self.backend = create_backend(list(rows))
indicies = list(range(self.row_count))
indices = list(range(self.row_count))
else:
indicies = self.backend.append_rows(rows)
indices = self.backend.append_rows(rows)
self._require_update_dimensions = True
self.cursor_coordinate = self.cursor_coordinate

Expand All @@ -1490,7 +1490,7 @@ def add_rows(self, rows: Iterable[Iterable[Any]]) -> list[int]:

self._update_count += 1
self.check_idle()
return indicies
return indices

def remove_row(self, row_index: int) -> None:
"""Remove a row (identified by a key) from the DataTable.
Expand Down Expand Up @@ -2354,7 +2354,7 @@ def _get_fixed_offset(self) -> Spacing:

def sort(
self,
by: list[tuple[str, Literal["ascending", "descending"]]] | str,
by: list[tuple[str, Literal["ascending", "descending"]]] | str | None,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys.

Expand Down