From 314079c75fe2f4649b7c783878db2fc3fe683be4 Mon Sep 17 00:00:00 2001
From: rlaphoenix <rlaphoenix@pm.me>
Date: Tue, 21 Feb 2023 16:09:35 +0000
Subject: [PATCH] Pass save path to DRM decrypt functions directly

This is required in segmented scenarios when multi-threaded where the same `track.path` would be get and set from possibly at the same time. It's also just better logically to do it this way.
---
 devine/commands/dl.py         |  5 +++--
 devine/core/drm/clearkey.py   | 18 +++++++++---------
 devine/core/drm/widevine.py   | 19 +++++++++++--------
 devine/core/manifests/dash.py | 16 ++++++----------
 devine/core/manifests/hls.py  |  8 +++-----
 5 files changed, 32 insertions(+), 34 deletions(-)

diff --git a/devine/commands/dl.py b/devine/commands/dl.py
index 4c5bc8b..ce7661c 100644
--- a/devine/commands/dl.py
+++ b/devine/commands/dl.py
@@ -630,7 +630,6 @@ class dl:
                 service.session.headers,
                 proxy if track.needs_proxy else None
             ))
-            track.path = save_path
 
             if not track.drm and isinstance(track, (Video, Audio)):
                 try:
@@ -644,7 +643,9 @@ class dl:
                 if isinstance(drm, Widevine):
                     # license and grab content keys
                     prepare_drm(drm)
-                drm.decrypt(track)
+                drm.decrypt(save_path)
+                track.drm = None
+                track.path = save_path
                 if callable(track.OnDecrypted):
                     track.OnDecrypted(track)
         else:
diff --git a/devine/core/drm/clearkey.py b/devine/core/drm/clearkey.py
index 5803bde..190f0da 100644
--- a/devine/core/drm/clearkey.py
+++ b/devine/core/drm/clearkey.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import shutil
+from pathlib import Path
 from typing import Optional, Union
 from urllib.parse import urljoin
 
@@ -7,8 +9,6 @@ import requests
 from Cryptodome.Cipher import AES
 from m3u8.model import Key
 
-from devine.core.constants import TrackT
-
 
 class ClearKey:
     """AES Clear Key DRM System."""
@@ -34,20 +34,20 @@ class ClearKey:
         self.key: bytes = key
         self.iv: bytes = iv
 
-    def decrypt(self, track: TrackT) -> None:
+    def decrypt(self, path: Path) -> None:
         """Decrypt a Track with AES Clear Key DRM."""
-        if not track.path or not track.path.exists():
-            raise ValueError("Tried to decrypt a track that has not yet been downloaded.")
+        if not path or not path.exists():
+            raise ValueError("Tried to decrypt a file that does not exist.")
 
         decrypted = AES. \
             new(self.key, AES.MODE_CBC, self.iv). \
-            decrypt(track.path.read_bytes())
+            decrypt(path.read_bytes())
 
-        decrypted_path = track.path.with_suffix(f".decrypted{track.path.suffix}")
+        decrypted_path = path.with_suffix(f".decrypted{path.suffix}")
         decrypted_path.write_bytes(decrypted)
 
-        track.swap(decrypted_path)
-        track.drm = None
+        path.unlink()
+        shutil.move(decrypted_path, path)
 
     @classmethod
     def from_m3u_key(cls, m3u_key: Key, proxy: Optional[str] = None) -> ClearKey:
diff --git a/devine/core/drm/widevine.py b/devine/core/drm/widevine.py
index 7897c46..12cf3f3 100644
--- a/devine/core/drm/widevine.py
+++ b/devine/core/drm/widevine.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import base64
+import shutil
 import subprocess
 import sys
+from pathlib import Path
 from typing import Any, Callable, Optional, Union
 from uuid import UUID
 
@@ -14,7 +16,7 @@ from pywidevine.pssh import PSSH
 from requests import Session
 
 from devine.core.config import config
-from devine.core.constants import AnyTrack, TrackT
+from devine.core.constants import AnyTrack
 from devine.core.utilities import get_binary_path, get_boxes
 from devine.core.utils.subprocess import ffprobe
 
@@ -212,7 +214,7 @@ class Widevine:
             finally:
                 cdm.close(session_id)
 
-    def decrypt(self, track: TrackT) -> None:
+    def decrypt(self, path: Path) -> None:
         """
         Decrypt a Track with Widevine DRM.
         Raises:
@@ -227,15 +229,15 @@ class Widevine:
         executable = get_binary_path("shaka-packager", f"packager-{platform}", f"packager-{platform}-x64")
         if not executable:
             raise EnvironmentError("Shaka Packager executable not found but is required.")
-        if not track.path or not track.path.exists():
-            raise ValueError("Tried to decrypt a track that has not yet been downloaded.")
+        if not path or not path.exists():
+            raise ValueError("Tried to decrypt a file that does not exist.")
 
-        decrypted_path = track.path.with_suffix(f".decrypted{track.path.suffix}")
+        decrypted_path = path.with_suffix(f".decrypted{path.suffix}")
         config.directories.temp.mkdir(parents=True, exist_ok=True)
         try:
             subprocess.check_call([
                 executable,
-                f"input={track.path},stream=0,output={decrypted_path}",
+                f"input={path},stream=0,output={decrypted_path}",
                 "--enable_raw_key_decryption", "--keys",
                 ",".join([
                     *[
@@ -252,8 +254,9 @@ class Widevine:
             ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
         except subprocess.CalledProcessError as e:
             raise subprocess.SubprocessError(f"Failed to Decrypt! Shaka Packager Error: {e}")
-        track.swap(decrypted_path)
-        track.drm = None
+
+        path.unlink()
+        shutil.move(decrypted_path, path)
 
     class Exceptions:
         class PSSHNotFound(Exception):
diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py
index 7d93cf3..5bb64ac 100644
--- a/devine/core/manifests/dash.py
+++ b/devine/core/manifests/dash.py
@@ -414,12 +414,9 @@ class DASH:
                     session.headers,
                     proxy
                 ))
-                # TODO: More like `segment.path`, but this will do for now
-                #       Needed for the drm.decrypt() call couple lines down
-                track.path = segment_save_path
 
                 if isinstance(track, Audio) or init_data:
-                    with open(track.path, "rb+") as f:
+                    with open(segment_save_path, "rb+") as f:
                         segment_data = f.read()
                         if isinstance(track, Audio):
                             # fix audio decryption on ATVP by fixing the sample description index
@@ -437,7 +434,8 @@ class DASH:
 
                 if drm:
                     # TODO: What if the manifest does not mention DRM, but has DRM
-                    drm.decrypt(track)
+                    drm.decrypt(segment_save_path)
+                    track.drm = None
                     if callable(track.OnDecrypted):
                         track.OnDecrypted(track)
         elif segment_list is not None:
@@ -485,12 +483,9 @@ class DASH:
                         proxy,
                         byte_range=segment_url.get("mediaRange")
                     ))
-                    # TODO: More like `segment.path`, but this will do for now
-                    #       Needed for the drm.decrypt() call couple lines down
-                    track.path = segment_save_path
 
                     if isinstance(track, Audio) or init_data:
-                        with open(track.path, "rb+") as f:
+                        with open(segment_save_path, "rb+") as f:
                             segment_data = f.read()
                             if isinstance(track, Audio):
                                 # fix audio decryption on ATVP by fixing the sample description index
@@ -508,7 +503,8 @@ class DASH:
 
                     if drm:
                         # TODO: What if the manifest does not mention DRM, but has DRM
-                        drm.decrypt(track)
+                        drm.decrypt(segment_save_path)
+                        track.drm = None
                         if callable(track.OnDecrypted):
                             track.OnDecrypted(track)
         elif segment_base is not None or base_url:
diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py
index 8698d6e..a518de7 100644
--- a/devine/core/manifests/hls.py
+++ b/devine/core/manifests/hls.py
@@ -269,12 +269,9 @@ class HLS:
                 session.headers,
                 proxy
             ))
-            # TODO: More like `segment.path`, but this will do for now
-            #       Needed for the drm.decrypt() call couple lines down
-            track.path = segment_save_path
 
             if isinstance(track, Audio) or init_data:
-                with open(track.path, "rb+") as f:
+                with open(segment_save_path, "rb+") as f:
                     segment_data = f.read()
                     if isinstance(track, Audio):
                         # fix audio decryption on ATVP by fixing the sample description index
@@ -291,7 +288,8 @@ class HLS:
                         f.write(segment_data)
 
             if last_segment_key[0]:
-                last_segment_key[0].decrypt(track)
+                last_segment_key[0].decrypt(segment_save_path)
+                track.drm = None
                 if callable(track.OnDecrypted):
                     track.OnDecrypted(track)