From 383e7d9647720c5ea8b1c039ee094ee02e9f651d Mon Sep 17 00:00:00 2001
From: rlaphoenix <rlaphoenix@pm.me>
Date: Tue, 28 Feb 2023 16:54:52 +0000
Subject: [PATCH] Add full support for CTRL+C on HLS and DASH

---
 devine/commands/dl.py         |  2 +
 devine/core/manifests/dash.py | 87 +++++++++++++++++++--------------
 devine/core/manifests/hls.py  | 91 ++++++++++++++++++++---------------
 3 files changed, 104 insertions(+), 76 deletions(-)

diff --git a/devine/commands/dl.py b/devine/commands/dl.py
index ecf052c..f01533b 100644
--- a/devine/commands/dl.py
+++ b/devine/commands/dl.py
@@ -736,6 +736,7 @@ class dl:
             HLS.download_track(
                 track=track,
                 save_dir=save_dir,
+                stop_event=self.DL_POOL_STOP,
                 progress=progress,
                 session=service.session,
                 proxy=proxy,
@@ -745,6 +746,7 @@ class dl:
             DASH.download_track(
                 track=track,
                 save_dir=save_dir,
+                stop_event=self.DL_POOL_STOP,
                 progress=progress,
                 session=service.session,
                 proxy=proxy,
diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py
index 86ebde7..57edcfe 100644
--- a/devine/core/manifests/dash.py
+++ b/devine/core/manifests/dash.py
@@ -273,6 +273,7 @@ class DASH:
     def download_track(
         track: AnyTrack,
         save_dir: Path,
+        stop_event: Event,
         progress: partial,
         session: Optional[Session] = None,
         proxy: Optional[str] = None,
@@ -445,10 +446,8 @@ class DASH:
                         raise ValueError("license_widevine func must be supplied to use Widevine DRM")
                     license_widevine(drm)
 
-            state_event = Event()
-
             def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
-                if state_event.is_set():
+                if stop_event.is_set():
                     return 0
 
                 segment_save_path = (save_dir / filename).with_suffix(".mp4")
@@ -509,43 +508,57 @@ class DASH:
             last_speed_refresh = time.time()
 
             with ThreadPoolExecutor(max_workers=16) as pool:
-                try:
-                    finished_threads = 0
-                    for download in futures.as_completed((
-                        pool.submit(
-                            download_segment,
-                            filename=str(i).zfill(len(str(len(segments)))),
-                            segment=segment
-                        )
-                        for i, segment in enumerate(segments)
-                    )):
-                        finished_threads += 1
-                        e = download.exception()
-                        if e:
-                            state_event.set()
-                            traceback.print_exception(e)
-                            log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
-                            sys.exit(1)
-                        else:
-                            progress(advance=1)
+                finished_threads = 0
+                has_stopped = False
+                has_failed = False
+                for download in futures.as_completed((
+                    pool.submit(
+                        download_segment,
+                        filename=str(i).zfill(len(str(len(segments)))),
+                        segment=segment
+                    )
+                    for i, segment in enumerate(segments)
+                )):
+                    finished_threads += 1
+                    try:
+                        download_size = download.result()
+                    except KeyboardInterrupt:
+                        stop_event.set()
+                        if not has_stopped:
+                            has_stopped = True
+                            progress(downloaded="[orange]STOPPING")
+                    except Exception as e:
+                        stop_event.set()
+                        if has_stopped:
+                            # we don't care because we were stopping anyway
+                            continue
+                        if not has_failed:
+                            has_failed = True
+                            progress(downloaded="[red]FAILING")
+                        traceback.print_exception(e)
+                        log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
+                    else:
+                        if stop_event.is_set():
+                            # skipped
+                            continue
+                        progress(advance=1)
 
-                            now = time.time()
-                            time_since = now - last_speed_refresh
+                        now = time.time()
+                        time_since = now - last_speed_refresh
 
-                            download_size = download.result()
-                            if download_size:  # no size == skipped dl
-                                download_sizes.append(download_size)
+                        if download_size:  # no size == skipped dl
+                            download_sizes.append(download_size)
 
-                            if time_since > 5 or finished_threads == len(segments):
-                                data_size = sum(download_sizes)
-                                download_speed = data_size / time_since
-                                progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
-                                last_speed_refresh = now
-                                download_sizes.clear()
-                except KeyboardInterrupt:
-                    state_event.set()
-                    log.info("Received Keyboard Interrupt, stopping...")
-                    return
+                        if time_since > 5 or finished_threads == len(segments):
+                            data_size = sum(download_sizes)
+                            download_speed = data_size / time_since
+                            progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
+                            last_speed_refresh = now
+                            download_sizes.clear()
+                if has_failed:
+                    progress(downloaded="[red]FAILED")
+                if has_stopped:
+                    progress(downloaded="[yellow]STOPPED")
 
     @staticmethod
     def get_language(*options: Any) -> Optional[Language]:
diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py
index 0c0b1ba..b38beee 100644
--- a/devine/core/manifests/hls.py
+++ b/devine/core/manifests/hls.py
@@ -182,6 +182,7 @@ class HLS:
     def download_track(
         track: AnyTrack,
         save_dir: Path,
+        stop_event: Event,
         progress: partial,
         session: Optional[Session] = None,
         proxy: Optional[str] = None,
@@ -212,10 +213,8 @@ class HLS:
             log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
             sys.exit(1)
 
-        state_event = Event()
-
         def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
-            if state_event.is_set():
+            if stop_event.is_set():
                 return 0
 
             segment_save_path = (save_dir / filename).with_suffix(".mp4")
@@ -347,45 +346,59 @@ class HLS:
         last_speed_refresh = time.time()
 
         with ThreadPoolExecutor(max_workers=16) as pool:
-            try:
-                finished_threads = 0
-                for download in futures.as_completed((
-                    pool.submit(
-                        download_segment,
-                        filename=str(i).zfill(len(str(len(master.segments)))),
-                        segment=segment,
-                        init_data=init_data,
-                        segment_key=segment_key
-                    )
-                    for i, segment in enumerate(master.segments)
-                )):
-                    finished_threads += 1
-                    e = download.exception()
-                    if e:
-                        state_event.set()
-                        traceback.print_exception(e)
-                        log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
-                        sys.exit(1)
-                    else:
-                        progress(advance=1)
+            finished_threads = 0
+            has_stopped = False
+            has_failed = False
+            for download in futures.as_completed((
+                pool.submit(
+                    download_segment,
+                    filename=str(i).zfill(len(str(len(master.segments)))),
+                    segment=segment,
+                    init_data=init_data,
+                    segment_key=segment_key
+                )
+                for i, segment in enumerate(master.segments)
+            )):
+                finished_threads += 1
+                try:
+                    download_size = download.result()
+                except KeyboardInterrupt:
+                    stop_event.set()
+                    if not has_stopped:
+                        has_stopped = True
+                        progress(downloaded="[orange]STOPPING")
+                except Exception as e:
+                    stop_event.set()
+                    if has_stopped:
+                        # we don't care because we were stopping anyway
+                        continue
+                    if not has_failed:
+                        has_failed = True
+                        progress(downloaded="[red]FAILING")
+                    traceback.print_exception(e)
+                    log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
+                else:
+                    if stop_event.is_set():
+                        # skipped
+                        continue
+                    progress(advance=1)
 
-                        now = time.time()
-                        time_since = now - last_speed_refresh
+                    now = time.time()
+                    time_since = now - last_speed_refresh
 
-                        download_size = download.result()
-                        if download_size:  # no size == skipped dl
-                            download_sizes.append(download_size)
+                    if download_size:  # no size == skipped dl
+                        download_sizes.append(download_size)
 
-                        if time_since > 5 or finished_threads == len(master.segments):
-                            data_size = sum(download_sizes)
-                            download_speed = data_size / time_since
-                            progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
-                            last_speed_refresh = now
-                            download_sizes.clear()
-            except KeyboardInterrupt:
-                state_event.set()
-                log.info("Received Keyboard Interrupt, stopping...")
-                return
+                    if time_since > 5 or finished_threads == len(master.segments):
+                        data_size = sum(download_sizes)
+                        download_speed = data_size / time_since
+                        progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
+                        last_speed_refresh = now
+                        download_sizes.clear()
+            if has_failed:
+                progress(downloaded="[red]FAILED")
+            if has_stopped:
+                progress(downloaded="[yellow]STOPPED")
 
     @staticmethod
     def get_drm(