diff --git a/cursorless-talon/src/csv_overrides.py b/cursorless-talon/src/csv_overrides.py index 4c51b7c3a4..3e359bbc29 100644 --- a/cursorless-talon/src/csv_overrides.py +++ b/cursorless-talon/src/csv_overrides.py @@ -1,8 +1,9 @@ import csv -from collections.abc import Container +from collections.abc import Container, Mapping +from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Callable, Optional from talon import Context, Module, actions, app, fs @@ -33,15 +34,16 @@ def init_csv_and_watch_changes( filename: str, - default_values: dict[str, dict[str, str]], - extra_ignored_values: Optional[list[str]] = None, + default_values: Mapping[str, Mapping[str, str]], + extra_ignored_values: Container[str] = (), + *, allow_unknown_values: bool = False, default_list_name: Optional[str] = None, headers: list[str] = [SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER], ctx: Context = Context(), no_update_file: bool = False, - pluralize_lists: Optional[list[str]] = [], -): + pluralize_lists: Container[str] = (), +) -> Callable[[], None]: """ Initialize a cursorless settings csv, creating it if necessary, and watch for changes to the csv. Talon lists will be generated based on the keys of @@ -60,27 +62,25 @@ def init_csv_and_watch_changes( `cursorles-settings` dir default_values (dict[str, dict]): The default values for the lists to be customized in the given csv - extra_ignored_values list[str]: Don't throw an exception if any of + extra_ignored_values: Don't throw an exception if any of these appear as values; just ignore them and don't add them to any list allow_unknown_values bool: If unknown values appear, just put them in the list default_list_name Optional[str]: If unknown values are allowed, put any unknown values in this list - no_update_file Optional[bool]: Set this to `TRUE` to indicate that we should + no_update_file: Set this to `TRUE` to indicate that we should not update the csv. This is used generally in case there was an issue coming up with the default set of values so we don't want to persist those to disk pluralize_lists: Create plural version of given lists """ - if extra_ignored_values is None: - extra_ignored_values = [] - file_path = get_full_path(filename) super_default_values = get_super_values(default_values) + assert allow_unknown_values == (default_list_name is not None) file_path.parent.mkdir(parents=True, exist_ok=True) check_for_duplicates(filename, default_values) create_default_vocabulary_dicts(default_values, pluralize_lists) - def on_watch(path, flags): + def on_watch(path: str, flags) -> None: if file_path.match(path): current_values, has_errors = read_file( file_path, @@ -93,7 +93,6 @@ def on_watch(path, flags): default_values, current_values, extra_ignored_values, - allow_unknown_values, default_list_name, pluralize_lists, ctx, @@ -114,7 +113,6 @@ def on_watch(path, flags): default_values, current_values, extra_ignored_values, - allow_unknown_values, default_list_name, pluralize_lists, ctx, @@ -126,22 +124,31 @@ def on_watch(path, flags): default_values, super_default_values, extra_ignored_values, - allow_unknown_values, default_list_name, pluralize_lists, ctx, ) - def unsubscribe(): + def unsubscribe() -> None: fs.unwatch(str(file_path.parent), on_watch) return unsubscribe -def check_for_duplicates(filename, default_values): +@dataclass +class ListValue: + key: str + value: str + #: The list name. + list: str + + +def check_for_duplicates( + filename: str, default_values: Mapping[str, Mapping[str, str]] +): results_map = {} - for list_name, dict in default_values.items(): - for key, value in dict.items(): + for list_name, values in default_values.items(): + for key, value in values.items(): if value in results_map: existing_list_name = results_map[value]["list"] warning = f"WARNING ({filename}): Value `{value}` duplicated between lists '{existing_list_name}' and '{list_name}'" @@ -149,12 +156,12 @@ def check_for_duplicates(filename, default_values): app.notify(warning) -def is_removed(value: str): +def is_removed(value: str) -> bool: return value.startswith("-") def create_default_vocabulary_dicts( - default_values: dict[str, dict], pluralize_lists: list[str] + default_values: Mapping[str, Mapping[str, str]], pluralize_lists: Container[str] ): default_values_updated = {} for key, value in default_values.items(): @@ -169,41 +176,38 @@ def create_default_vocabulary_dicts( def update_dicts( - default_values: dict[str, dict], - current_values: dict, - extra_ignored_values: list[str], - allow_unknown_values: bool, + default_values: Mapping[str, Mapping[str, str]], + current_values: dict[str, str], + extra_ignored_values: Container[str], default_list_name: Optional[str], - pluralize_lists: list[str], + pluralize_lists: Container[str], ctx: Context, -): +) -> None: # Create map with all default values - results_map = {} - for list_name, dict in default_values.items(): - for key, value in dict.items(): - results_map[value] = {"key": key, "value": value, "list": list_name} + results_map: dict[str, ListValue] = {} + for list_name, values in default_values.items(): + for key, value in values.items(): + results_map[value] = ListValue(key=key, value=value, list=list_name) # Update result with current values for key, value in current_values.items(): try: - results_map[value]["key"] = key + results_map[value].key = key except KeyError: if value in extra_ignored_values: pass - elif allow_unknown_values: - results_map[value] = { - "key": key, - "value": value, - "list": default_list_name, - } + elif default_list_name is not None: + results_map[value] = ListValue( + key=key, value=value, list=default_list_name + ) else: raise # Convert result map back to result list - results = {res["list"]: {} for res in results_map.values()} + results: dict[str, dict[str, str]] = {res.list: {} for res in results_map.values()} for obj in results_map.values(): - value = obj["value"] - key = obj["key"] + value = obj.value + key = obj.key if not is_removed(key): for k in key.split("|"): if value == "pasteFromClipboard" and k.endswith(" to"): @@ -214,7 +218,7 @@ def update_dicts( # cursorless before this change would have "paste to" as # their spoken form and so would need to say "paste to to". k = k[:-3] - results[obj["list"]][k.strip()] = value + results[obj.list][k.strip()] = value # Assign result to talon context list assign_lists_to_context(ctx, results, pluralize_lists) @@ -222,25 +226,25 @@ def update_dicts( def assign_lists_to_context( ctx: Context, - results: dict, - pluralize_lists: list[str], + results: dict[str, dict[str, str]], + pluralize_lists: Container[str], ): - for list_name, dict in results.items(): + for list_name, values in results.items(): list_singular_name = get_cursorless_list_name(list_name) - ctx.lists[list_singular_name] = dict + ctx.lists[list_singular_name] = values if list_name in pluralize_lists: list_plural_name = f"{list_singular_name}_plural" - ctx.lists[list_plural_name] = {pluralize(k): v for k, v in dict.items()} + ctx.lists[list_plural_name] = {pluralize(k): v for k, v in values.items()} def update_file( path: Path, headers: list[str], default_values: dict[str, str], - extra_ignored_values: list[str], + extra_ignored_values: Container[str], allow_unknown_values: bool, no_update_file: bool, -): +) -> dict[str, str]: current_values, has_errors = read_file( path, headers, @@ -250,7 +254,7 @@ def update_file( ) current_identifiers = current_values.values() - missing = {} + missing: dict[str, str] = {} for key, value in default_values.items(): if value not in current_identifiers: missing[key] = value @@ -263,16 +267,17 @@ def update_file( ) else: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + missing_items = sorted(missing.items()) lines = [ f"# {timestamp} - New entries automatically added by cursorless", - *[create_line(key, missing[key]) for key in sorted(missing)], + *[create_line(key, value) for key, value in missing_items], ] with open(path, "a") as f: f.write("\n\n" + "\n".join(lines)) print(f"New cursorless features added to {path.name}") - for key in sorted(missing): - print(f"{key}: {missing[key]}") + for key, value in missing_items: + print(f"{key}: {value}") print( "See release notes for more info: " "https://github.com/cursorless-dev/cursorless/blob/main/CHANGELOG.md" @@ -282,18 +287,18 @@ def update_file( return current_values -def create_line(*cells: str): +def create_line(*cells: str) -> str: return ", ".join(cells) -def create_file(path: Path, headers: list[str], default_values: dict): - lines = [create_line(key, default_values[key]) for key in sorted(default_values)] +def create_file(path: Path, headers: list[str], default_values: dict[str, str]) -> None: + lines = [create_line(key, value) for key, value in sorted(default_values.items())] lines.insert(0, create_line(*headers)) lines.append("") path.write_text("\n".join(lines)) -def csv_error(path: Path, index: int, message: str, value: str): +def csv_error(path: Path, index: int, message: str, value: str) -> None: """Check that an expected condition is true Note that we try to continue reading in this case so cursorless doesn't get bricked @@ -310,16 +315,15 @@ def read_file( path: Path, headers: list[str], default_identifiers: Container[str], - extra_ignored_values: list[str], + extra_ignored_values: Container[str], allow_unknown_values: bool, -): +) -> tuple[dict[str, str], bool]: with open(path) as csv_file: # Use `skipinitialspace` to allow spaces before quote. `, "a,b"` csv_reader = csv.reader(csv_file, skipinitialspace=True) rows = list(csv_reader) - result = {} - used_identifiers = [] + result: dict[str, str] = {} has_errors = False seen_headers = False @@ -359,13 +363,12 @@ def read_file( csv_error(path, i, "Unknown identifier", value) continue - if value in used_identifiers: + if value in result: has_errors = True csv_error(path, i, "Duplicate identifier", value) continue result[key] = value - used_identifiers.append(value) if has_errors: app.notify("Cursorless settings error; see log") @@ -373,7 +376,7 @@ def read_file( return result, has_errors -def get_full_path(filename: str): +def get_full_path(filename: str) -> Path: if not filename.endswith(".csv"): filename = f"{filename}.csv" @@ -386,7 +389,7 @@ def get_full_path(filename: str): return (settings_directory / filename).resolve() -def get_super_values(values: dict[str, dict[str, str]]): +def get_super_values(values: Mapping[str, Mapping[str, str]]) -> dict[str, str]: result: dict[str, str] = {} for value_dict in values.values(): result.update(value_dict)