Add type hints to function arguments

This commit is contained in:
2Shirt 2023-05-28 20:50:38 -07:00
parent 62edaac25a
commit 171cd0019e
Signed by: 2Shirt
GPG key ID: 152FAC923B0E132C
5 changed files with 75 additions and 29 deletions

View file

@ -35,7 +35,10 @@ THRESH_GREAT = 750 * 1024**2
# Functions
def generate_horizontal_graph(rate_list, graph_width=40, oneline=False) -> list[str]:
def generate_horizontal_graph(
rate_list: list[float],
graph_width: int = 40,
oneline: bool = False) -> list[str]:
"""Generate horizontal graph from rate_list, returns list."""
graph = ['', '', '', '']
scale = 8 if oneline else 32
@ -82,7 +85,7 @@ def generate_horizontal_graph(rate_list, graph_width=40, oneline=False) -> list[
return graph
def get_graph_step(rate, scale=16) -> int:
def get_graph_step(rate: float, scale: int = 16) -> int:
"""Get graph step based on rate and scale, returns int."""
rate_in_mb = rate / (1024**2)
step = 0
@ -97,7 +100,10 @@ def get_graph_step(rate, scale=16) -> int:
return step
def merge_rates(rates, graph_width=40) -> list[Union[int, float]]:
def merge_rates(
rates: list[float],
graph_width: int = 40,
) -> list[Union[int, float]]:
"""Merge rates to have entries equal to the width, returns list."""
merged_rates = []
offset = 0
@ -112,7 +118,7 @@ def merge_rates(rates, graph_width=40) -> list[Union[int, float]]:
return merged_rates
def vertical_graph_line(percent, rate, scale=32) -> str:
def vertical_graph_line(percent: float, rate: float, scale: int = 32) -> str:
"""Build colored graph string using thresholds, returns str."""
color_bar = None
color_rate = None

View file

@ -7,13 +7,15 @@ import pathlib
import re
import shutil
from typing import Union
# STATIC VARIABLES
LOG = logging.getLogger(__name__)
# Functions
def case_insensitive_path(path) -> pathlib.Path:
def case_insensitive_path(path: Union[pathlib.Path, str]) -> pathlib.Path:
"""Find path case-insensitively, returns pathlib.Path obj."""
given_path = pathlib.Path(path).resolve()
real_path = None
@ -37,7 +39,8 @@ def case_insensitive_path(path) -> pathlib.Path:
return real_path
def case_insensitive_search(path, item) -> pathlib.Path:
def case_insensitive_search(
path: Union[pathlib.Path, str], item: str) -> pathlib.Path:
"""Search path for item case insensitively, returns pathlib.Path obj."""
path = pathlib.Path(path).resolve()
given_path = path.joinpath(item)
@ -61,7 +64,10 @@ def case_insensitive_search(path, item) -> pathlib.Path:
return real_path
def copy_file(source, dest, overwrite=False) -> None:
def copy_file(
source: Union[pathlib.Path, str],
dest: Union[pathlib.Path, str],
overwrite: bool = False) -> None:
"""Copy file and optionally overwrite the destination."""
source = case_insensitive_path(source)
dest = pathlib.Path(dest).resolve()
@ -72,7 +78,7 @@ def copy_file(source, dest, overwrite=False) -> None:
shutil.copy2(source, dest)
def delete_empty_folders(path) -> None:
def delete_empty_folders(path: Union[pathlib.Path, str]) -> None:
"""Recursively delete all empty folders in path."""
LOG.debug('path: %s', path)
@ -89,7 +95,11 @@ def delete_empty_folders(path) -> None:
pass
def delete_folder(path, force=False, ignore_errors=False) -> None:
def delete_folder(
path: Union[pathlib.Path, str],
force: bool = False,
ignore_errors: bool = False,
) -> None:
"""Delete folder if empty or if forced.
NOTE: Exceptions are not caught by this function,
@ -106,7 +116,11 @@ def delete_folder(path, force=False, ignore_errors=False) -> None:
os.rmdir(path)
def delete_item(path, force=False, ignore_errors=False) -> None:
def delete_item(
path: Union[pathlib.Path, str],
force: bool = False,
ignore_errors: bool = False,
) -> None:
"""Delete file or folder, optionally recursively.
NOTE: Exceptions are not caught by this function,
@ -124,7 +138,11 @@ def delete_item(path, force=False, ignore_errors=False) -> None:
os.remove(path)
def get_path_obj(path, expanduser=True, resolve=True) -> pathlib.Path:
def get_path_obj(
path: Union[pathlib.Path, str],
expanduser: bool = True,
resolve: bool = True,
) -> pathlib.Path:
"""Get based on path, returns pathlib.Path."""
path = pathlib.Path(path)
if expanduser:
@ -134,7 +152,7 @@ def get_path_obj(path, expanduser=True, resolve=True) -> pathlib.Path:
return path
def non_clobber_path(path) -> pathlib.Path:
def non_clobber_path(path: Union[pathlib.Path, str]) -> pathlib.Path:
"""Update path as needed to non-existing path, returns pathlib.Path."""
LOG.debug('path: %s', path)
path = pathlib.Path(path)
@ -163,7 +181,10 @@ def non_clobber_path(path) -> pathlib.Path:
return new_path
def recursive_copy(source, dest, overwrite=False) -> None:
def recursive_copy(
source: Union[pathlib.Path, str],
dest: Union[pathlib.Path, str],
overwrite: bool = False) -> None:
"""Copy source to dest recursively.
NOTE: This uses rsync style source/dest syntax.
@ -213,7 +234,10 @@ def recursive_copy(source, dest, overwrite=False) -> None:
raise FileExistsError(f'Refusing to delete file: {dest}')
def rename_item(path, new_path) -> pathlib.Path:
def rename_item(
path: Union[pathlib.Path, str],
new_path: Union[pathlib.Path, str],
) -> pathlib.Path:
"""Rename item, returns pathlib.Path."""
path = pathlib.Path(path)
return path.rename(new_path)

View file

@ -8,6 +8,8 @@ import pathlib
import shutil
import time
from typing import Union
from wk import cfg
from wk.io import non_clobber_path
@ -39,8 +41,11 @@ def enable_debug_mode() -> None:
def format_log_path(
log_dir=None, log_name=None, timestamp=False,
kit=False, tool=False, append=False) -> pathlib.Path:
log_dir: Union[None, pathlib.Path, str] = None,
log_name: Union[None, str] = None,
timestamp: bool = False,
kit: bool = False, tool: bool = False, append: bool = False,
) -> pathlib.Path:
"""Format path based on args passed, returns pathlib.Path obj."""
log_path = pathlib.Path(
f'{log_dir if log_dir else DEFAULT_LOG_DIR}/'
@ -78,7 +83,7 @@ def get_root_logger_path() -> pathlib.Path:
raise RuntimeError('Log path not found.')
def remove_empty_log(log_path=None) -> None:
def remove_empty_log(log_path: Union[None, pathlib.Path] = None) -> None:
"""Remove log if empty.
NOTE: Under Windows an empty log is 2 bytes long.
@ -101,7 +106,7 @@ def remove_empty_log(log_path=None) -> None:
log_path.unlink()
def start(config=None) -> None:
def start(config: Union[dict[str, str], None] = None) -> None:
"""Configure and start logging using safe defaults."""
log_path = format_log_path(timestamp=os.name != 'nt')
root_logger = logging.getLogger()
@ -124,8 +129,10 @@ def start(config=None) -> None:
def update_log_path(
dest_dir=None, dest_name=None, keep_history=True,
timestamp=True, append=False) -> None:
dest_dir: Union[None, pathlib.Path, str] = None,
dest_name: Union[None, str] = None,
keep_history: bool = True, timestamp: bool = True, append: bool = False,
) -> None:
"""Moves current log file to new path and updates the root logger."""
root_logger = logging.getLogger()
new_path = format_log_path(dest_dir, dest_name, timestamp=timestamp, append=append)

View file

@ -6,6 +6,7 @@ import pathlib
import re
from subprocess import CompletedProcess
from typing import Any, Union
import psutil
@ -25,7 +26,7 @@ REGEX_VALID_IP = re.compile(
# Functions
def connected_to_private_network(raise_on_error=False) -> bool:
def connected_to_private_network(raise_on_error: bool = False) -> bool:
"""Check if connected to a private network, returns bool.
This checks for a valid private IP assigned to this system.
@ -54,7 +55,7 @@ def connected_to_private_network(raise_on_error=False) -> bool:
return connected
def mount_backup_shares(read_write=False) -> list[str]:
def mount_backup_shares(read_write: bool = False) -> list[str]:
"""Mount backup shares using OS specific methods."""
report = []
for name, details in BACKUP_SERVERS.items():
@ -98,7 +99,9 @@ def mount_backup_shares(read_write=False) -> list[str]:
def mount_network_share(
details, mount_point=None, read_write=False) -> CompletedProcess:
details: dict[str, Any],
mount_point: Union[None, pathlib.Path, str] = None,
read_write: bool = False) -> CompletedProcess:
"""Mount network share using OS specific methods."""
cmd = None
address = details['Address']
@ -149,7 +152,7 @@ def mount_network_share(
return run_program(cmd, check=False)
def ping(addr='google.com') -> None:
def ping(addr: str = 'google.com') -> None:
"""Attempt to ping addr."""
cmd = (
'ping',
@ -160,7 +163,7 @@ def ping(addr='google.com') -> None:
run_program(cmd)
def share_is_mounted(details) -> bool:
def share_is_mounted(details: dict[str, Any]) -> bool:
"""Check if dev/share/etc is mounted, returns bool."""
mounted = False
@ -245,7 +248,10 @@ def unmount_backup_shares() -> list[str]:
return report
def unmount_network_share(details=None, mount_point=None) -> CompletedProcess:
def unmount_network_share(
details: Union[dict[str, Any], None] = None,
mount_point: Union[None, pathlib.Path, str] = None,
) -> CompletedProcess:
"""Unmount network share"""
cmd = []

View file

@ -33,7 +33,10 @@ class GenericWarning(Exception):
# Functions
def bytes_to_string(size, decimals=0, use_binary=True) -> str:
def bytes_to_string(
size: Union[float, int],
decimals: int = 0,
use_binary: bool = True) -> str:
"""Convert size into a human-readable format, returns str.
[Doctest]
@ -80,8 +83,8 @@ def sleep(seconds: Union[int, float] = 2) -> None:
time.sleep(seconds)
def string_to_bytes(size, assume_binary=False) -> int:
"""Convert human-readable size str to bytes and return an int."""
def string_to_bytes(size: Union[float, int, str], assume_binary: bool = False) -> int:
"""Convert human-readable size to bytes and return an int."""
LOG.debug('size: %s, assume_binary: %s', size, assume_binary)
scale = 1000
size = str(size)