Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

csv_overrides: Add stronger type hints #1851

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 68 additions & 65 deletions cursorless-talon/src/csv_overrides.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import csv
from collections.abc import Container
from collections.abc import Container, Mapping
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Callable, Optional

from talon import Context, Module, actions, app, fs

Expand Down Expand Up @@ -33,15 +34,16 @@

def init_csv_and_watch_changes(
filename: str,
default_values: dict[str, dict[str, str]],
extra_ignored_values: Optional[list[str]] = None,
default_values: Mapping[str, Mapping[str, str]],
extra_ignored_values: Container[str] = (),
*,
allow_unknown_values: bool = False,
default_list_name: Optional[str] = None,
headers: list[str] = [SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER],
ctx: Context = Context(),
no_update_file: bool = False,
pluralize_lists: Optional[list[str]] = [],
):
pluralize_lists: Container[str] = (),
) -> Callable[[], None]:
"""
Initialize a cursorless settings csv, creating it if necessary, and watch
for changes to the csv. Talon lists will be generated based on the keys of
Expand All @@ -60,27 +62,25 @@ def init_csv_and_watch_changes(
`cursorles-settings` dir
default_values (dict[str, dict]): The default values for the lists to
be customized in the given csv
extra_ignored_values list[str]: Don't throw an exception if any of
extra_ignored_values: Don't throw an exception if any of
these appear as values; just ignore them and don't add them to any list
allow_unknown_values bool: If unknown values appear, just put them in the list
default_list_name Optional[str]: If unknown values are allowed, put any
unknown values in this list
no_update_file Optional[bool]: Set this to `TRUE` to indicate that we should
no_update_file: Set this to `TRUE` to indicate that we should
not update the csv. This is used generally in case there was an issue coming up with the default set of values so we don't want to persist those to disk
pluralize_lists: Create plural version of given lists
"""
if extra_ignored_values is None:
extra_ignored_values = []

file_path = get_full_path(filename)
super_default_values = get_super_values(default_values)
assert allow_unknown_values == (default_list_name is not None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out we have dependent types here! It would've been too annoying to check both allow_unknown_values and default_list_name in update_dicts to appease the type checker, but I'm also too lazy to update the callers of init_csv_and_watch_changes right now.


file_path.parent.mkdir(parents=True, exist_ok=True)

check_for_duplicates(filename, default_values)
create_default_vocabulary_dicts(default_values, pluralize_lists)

def on_watch(path, flags):
def on_watch(path: str, flags) -> None:
if file_path.match(path):
current_values, has_errors = read_file(
file_path,
Expand All @@ -93,7 +93,6 @@ def on_watch(path, flags):
default_values,
current_values,
extra_ignored_values,
allow_unknown_values,
default_list_name,
pluralize_lists,
ctx,
Expand All @@ -114,7 +113,6 @@ def on_watch(path, flags):
default_values,
current_values,
extra_ignored_values,
allow_unknown_values,
default_list_name,
pluralize_lists,
ctx,
Expand All @@ -126,35 +124,44 @@ def on_watch(path, flags):
default_values,
super_default_values,
extra_ignored_values,
allow_unknown_values,
default_list_name,
pluralize_lists,
ctx,
)

def unsubscribe():
def unsubscribe() -> None:
fs.unwatch(str(file_path.parent), on_watch)

return unsubscribe


def check_for_duplicates(filename, default_values):
@dataclass
class ListValue:
key: str
value: str
#: The list name.
list: str


def check_for_duplicates(
filename: str, default_values: Mapping[str, Mapping[str, str]]
):
results_map = {}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see what type this is supposed to be... because I don't think any items are ever assigned to this 😕

I recall this bit was in one of the other functions previously. Think this ended up being broken after a recent refactor.

for list_name, dict in default_values.items():
for key, value in dict.items():
for list_name, values in default_values.items():
for key, value in values.items():
if value in results_map:
existing_list_name = results_map[value]["list"]
warning = f"WARNING ({filename}): Value `{value}` duplicated between lists '{existing_list_name}' and '{list_name}'"
print(warning)
app.notify(warning)


def is_removed(value: str):
def is_removed(value: str) -> bool:
return value.startswith("-")


def create_default_vocabulary_dicts(
default_values: dict[str, dict], pluralize_lists: list[str]
default_values: Mapping[str, Mapping[str, str]], pluralize_lists: Container[str]
):
default_values_updated = {}
for key, value in default_values.items():
Expand All @@ -169,41 +176,38 @@ def create_default_vocabulary_dicts(


def update_dicts(
default_values: dict[str, dict],
current_values: dict,
extra_ignored_values: list[str],
allow_unknown_values: bool,
default_values: Mapping[str, Mapping[str, str]],
current_values: dict[str, str],
extra_ignored_values: Container[str],
default_list_name: Optional[str],
pluralize_lists: list[str],
pluralize_lists: Container[str],
ctx: Context,
):
) -> None:
# Create map with all default values
results_map = {}
for list_name, dict in default_values.items():
for key, value in dict.items():
results_map[value] = {"key": key, "value": value, "list": list_name}
results_map: dict[str, ListValue] = {}
for list_name, values in default_values.items():
for key, value in values.items():
results_map[value] = ListValue(key=key, value=value, list=list_name)

# Update result with current values
for key, value in current_values.items():
try:
results_map[value]["key"] = key
results_map[value].key = key
except KeyError:
if value in extra_ignored_values:
pass
elif allow_unknown_values:
results_map[value] = {
"key": key,
"value": value,
"list": default_list_name,
}
elif default_list_name is not None:
results_map[value] = ListValue(
key=key, value=value, list=default_list_name
)
else:
raise

# Convert result map back to result list
results = {res["list"]: {} for res in results_map.values()}
results: dict[str, dict[str, str]] = {res.list: {} for res in results_map.values()}
for obj in results_map.values():
value = obj["value"]
key = obj["key"]
value = obj.value
key = obj.key
if not is_removed(key):
for k in key.split("|"):
if value == "pasteFromClipboard" and k.endswith(" to"):
Expand All @@ -214,33 +218,33 @@ def update_dicts(
# cursorless before this change would have "paste to" as
# their spoken form and so would need to say "paste to to".
k = k[:-3]
results[obj["list"]][k.strip()] = value
results[obj.list][k.strip()] = value

# Assign result to talon context list
assign_lists_to_context(ctx, results, pluralize_lists)


def assign_lists_to_context(
ctx: Context,
results: dict,
pluralize_lists: list[str],
results: dict[str, dict[str, str]],
pluralize_lists: Container[str],
):
for list_name, dict in results.items():
for list_name, values in results.items():
list_singular_name = get_cursorless_list_name(list_name)
ctx.lists[list_singular_name] = dict
ctx.lists[list_singular_name] = values
if list_name in pluralize_lists:
list_plural_name = f"{list_singular_name}_plural"
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in dict.items()}
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in values.items()}


def update_file(
path: Path,
headers: list[str],
default_values: dict[str, str],
extra_ignored_values: list[str],
extra_ignored_values: Container[str],
allow_unknown_values: bool,
no_update_file: bool,
):
) -> dict[str, str]:
current_values, has_errors = read_file(
path,
headers,
Expand All @@ -250,7 +254,7 @@ def update_file(
)
current_identifiers = current_values.values()

missing = {}
missing: dict[str, str] = {}
for key, value in default_values.items():
if value not in current_identifiers:
missing[key] = value
Expand All @@ -263,16 +267,17 @@ def update_file(
)
else:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
missing_items = sorted(missing.items())
lines = [
f"# {timestamp} - New entries automatically added by cursorless",
*[create_line(key, missing[key]) for key in sorted(missing)],
*[create_line(key, value) for key, value in missing_items],
]
with open(path, "a") as f:
f.write("\n\n" + "\n".join(lines))

print(f"New cursorless features added to {path.name}")
for key in sorted(missing):
print(f"{key}: {missing[key]}")
for key, value in missing_items:
print(f"{key}: {value}")
print(
"See release notes for more info: "
"https://github.com/cursorless-dev/cursorless/blob/main/CHANGELOG.md"
Expand All @@ -282,18 +287,18 @@ def update_file(
return current_values


def create_line(*cells: str):
def create_line(*cells: str) -> str:
return ", ".join(cells)


def create_file(path: Path, headers: list[str], default_values: dict):
lines = [create_line(key, default_values[key]) for key in sorted(default_values)]
def create_file(path: Path, headers: list[str], default_values: dict[str, str]) -> None:
lines = [create_line(key, value) for key, value in sorted(default_values.items())]
lines.insert(0, create_line(*headers))
lines.append("")
path.write_text("\n".join(lines))


def csv_error(path: Path, index: int, message: str, value: str):
def csv_error(path: Path, index: int, message: str, value: str) -> None:
"""Check that an expected condition is true

Note that we try to continue reading in this case so cursorless doesn't get bricked
Expand All @@ -310,16 +315,15 @@ def read_file(
path: Path,
headers: list[str],
default_identifiers: Container[str],
extra_ignored_values: list[str],
extra_ignored_values: Container[str],
allow_unknown_values: bool,
):
) -> tuple[dict[str, str], bool]:
with open(path) as csv_file:
# Use `skipinitialspace` to allow spaces before quote. `, "a,b"`
csv_reader = csv.reader(csv_file, skipinitialspace=True)
rows = list(csv_reader)

result = {}
used_identifiers = []
result: dict[str, str] = {}
has_errors = False
seen_headers = False

Expand Down Expand Up @@ -359,21 +363,20 @@ def read_file(
csv_error(path, i, "Unknown identifier", value)
continue

if value in used_identifiers:
if value in result:
has_errors = True
csv_error(path, i, "Duplicate identifier", value)
continue

result[key] = value
used_identifiers.append(value)

if has_errors:
app.notify("Cursorless settings error; see log")

return result, has_errors


def get_full_path(filename: str):
def get_full_path(filename: str) -> Path:
if not filename.endswith(".csv"):
filename = f"{filename}.csv"

Expand All @@ -386,7 +389,7 @@ def get_full_path(filename: str):
return (settings_directory / filename).resolve()


def get_super_values(values: dict[str, dict[str, str]]):
def get_super_values(values: Mapping[str, Mapping[str, str]]) -> dict[str, str]:
result: dict[str, str] = {}
for value_dict in values.values():
result.update(value_dict)
Expand Down