Skip to content

Commit

Permalink
🩹 (io) handle stdin/out str vs bytest properly
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwoerpel committed Nov 15, 2023
1 parent 98d820c commit 241cac0
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions ftmq/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ class SmartHandler:
def __init__(
self,
uri: Any,
sys_io: Literal[sys.stdin.buffer, sys.stdout.buffer] | None = sys.stdin,
sys_io: Literal[sys.stdin, sys.stdout] | None = sys.stdin,
*args,
**kwargs,
) -> None:
if not uri:
raise ValueError("Missing uri")
self.uri = str(uri)
self.sys_io = sys_io
self.args = args
kwargs["mode"] = kwargs.get("mode", "rb")
if kwargs["mode"].endswith("b"):
sys_io = sys_io.buffer
self.sys_io = sys_io
self.kwargs = kwargs
self.is_buffer = self.uri == "-"
self.handler = None
Expand All @@ -58,7 +60,7 @@ def __exit__(self, *args, **kwargs) -> None:
@contextlib.contextmanager
def smart_open(
uri: Any,
sys_io: Literal[sys.stdin.buffer, sys.stdout.buffer] | None = sys.stdin,
sys_io: Literal[sys.stdin, sys.stdout] | None = sys.stdin,
*args,
**kwargs,
):
Expand All @@ -71,7 +73,7 @@ def smart_open(

def _smart_stream(uri, *args, **kwargs) -> Any:
kwargs["mode"] = kwargs.get("mode", "rb")
with smart_open(uri, sys.stdin.buffer, *args, **kwargs) as fh:
with smart_open(uri, sys.stdin, *args, **kwargs) as fh:
while line := fh.readline():
yield line

Expand All @@ -82,13 +84,13 @@ def smart_read(uri, *args, **kwargs) -> Any:
if stream:
return _smart_stream(uri, *args, **kwargs)

with smart_open(uri, sys.stdin.buffer, *args, **kwargs) as fh:
with smart_open(uri, sys.stdin, *args, **kwargs) as fh:
return fh.read()


def smart_write(uri, content: bytes | str, *args, **kwargs) -> Any:
kwargs["mode"] = kwargs.get("mode", "wb")
with smart_open(uri, sys.stdout.buffer, *args, **kwargs) as fh:
with smart_open(uri, sys.stdout, *args, **kwargs) as fh:
fh.write(content)


Expand Down Expand Up @@ -151,7 +153,7 @@ def smart_write_proxies(
log.info("Writing proxy %d ..." % ix)
return ix

with smart_open(uri, sys.stdout.buffer, mode=mode) as fh:
with smart_open(uri, sys.stdout, mode=mode) as fh:
for proxy in proxies:
ix += 1
if serialize:
Expand Down

0 comments on commit 241cac0

Please sign in to comment.