From 1f86775ac97a2039aa12fc639945b929ec2cf154 Mon Sep 17 00:00:00 2001
From: rlaphoenix <rlaphoenix@pm.me>
Date: Thu, 23 Feb 2023 16:35:02 +0000
Subject: [PATCH] Add support for segment downloads with byte-ranges

Adds support for HLS's EXT-X-BYTERANGE and DASH's SegmentBase.
---
 devine/core/manifests/dash.py | 26 ++++++++++++++------
 devine/core/manifests/hls.py  | 46 ++++++++++++++++++++++++++++-------
 2 files changed, 56 insertions(+), 16 deletions(-)

diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py
index 40f09d0..046d0f6 100644
--- a/devine/core/manifests/dash.py
+++ b/devine/core/manifests/dash.py
@@ -455,13 +455,25 @@ class DASH:
 
                 segment_uri, segment_range = segment
 
-                asyncio.run(aria2c(
-                    segment_uri,
-                    segment_save_path,
-                    session.headers,
-                    proxy,
-                    silent=True
-                ))
+                if segment_range:
+                    # aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
+                    r = session.get(
+                        url=segment_uri,
+                        headers={
+                            "Range": f"bytes={segment_range}"
+                        }
+                    )
+                    r.raise_for_status()
+                    segment_save_path.parent.mkdir(parents=True, exist_ok=True)
+                    segment_save_path.write_bytes(res.content)
+                else:
+                    asyncio.run(aria2c(
+                        segment_uri,
+                        segment_save_path,
+                        session.headers,
+                        proxy,
+                        silent=True
+                    ))
 
                 if isinstance(track, Audio) or init_data:
                     with open(segment_save_path, "rb+") as f:
diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py
index 7baca7e..728075d 100644
--- a/devine/core/manifests/hls.py
+++ b/devine/core/manifests/hls.py
@@ -213,7 +213,13 @@ class HLS:
 
         state_event = Event()
 
-        def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
+        def download_segment(
+            filename: str,
+            segment: m3u8.Segment,
+            init_data: Queue,
+            segment_key: Queue,
+            range_offset: Queue
+        ) -> None:
             time.sleep(0.1)
             if state_event.is_set():
                 return
@@ -268,13 +274,32 @@ class HLS:
             if not segment.uri.startswith(segment.base_uri):
                 segment.uri = segment.base_uri + segment.uri
 
-            asyncio.run(aria2c(
-                segment.uri,
-                segment_save_path,
-                session.headers,
-                proxy,
-                silent=True
-            ))
+            if segment.byterange:
+                previous_range_offset = range_offset.get()
+                byte_range_parts = [int(x) for x in segment.byterange.split("@")]
+                if len(byte_range_parts) != 2:
+                    byte_range_parts.append(previous_range_offset)
+                range_length, newest_range_offset = byte_range_parts
+                range_offset.put(newest_range_offset)
+
+                # aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
+                res = session.get(
+                    url=segment.uri,
+                    headers={
+                        "Range": f"bytes={newest_range_offset}-{newest_range_offset+range_length}"
+                    }
+                )
+                res.raise_for_status()
+                segment_save_path.parent.mkdir(parents=True, exist_ok=True)
+                segment_save_path.write_bytes(res.content)
+            else:
+                asyncio.run(aria2c(
+                    segment.uri,
+                    segment_save_path,
+                    session.headers,
+                    proxy,
+                    silent=True
+                ))
 
             if isinstance(track, Audio) or newest_init_data:
                 with open(segment_save_path, "rb+") as f:
@@ -301,6 +326,7 @@ class HLS:
 
         segment_key = Queue(maxsize=1)
         init_data = Queue(maxsize=1)
+        range_offset = Queue(maxsize=1)
 
         if track.drm:
             session_drm = track.drm[0]  # just use the first supported DRM system for now
@@ -315,6 +341,7 @@ class HLS:
         # have data to begin with, or it will be stuck waiting on the first pool forever
         segment_key.put((session_drm, None))
         init_data.put(None)
+        range_offset.put(0)
 
         with tqdm(total=len(master.segments), unit="segments") as pbar:
             with ThreadPoolExecutor(max_workers=16) as pool:
@@ -325,7 +352,8 @@ class HLS:
                             filename=str(i).zfill(len(str(len(master.segments)))),
                             segment=segment,
                             init_data=init_data,
-                            segment_key=segment_key
+                            segment_key=segment_key,
+                            range_offset=range_offset
                         )
                         for i, segment in enumerate(master.segments)
                     )):