From 3d794ad659cd1ad56593a074986392ebd7ed4afb Mon Sep 17 00:00:00 2001
From: rlaphoenix <rlaphoenix@pm.me>
Date: Thu, 4 Aug 2022 05:40:59 +0100
Subject: [PATCH] RemoteCdm: Implement /set_service_certificate

---
 pywidevine/remotecdm.py | 37 ++++++++++++++++++++++++++++---------
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/pywidevine/remotecdm.py b/pywidevine/remotecdm.py
index ae253eb..1786427 100644
--- a/pywidevine/remotecdm.py
+++ b/pywidevine/remotecdm.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import base64
 import binascii
-from typing import Union
+from typing import Union, Optional
 
 import requests
 from Crypto.PublicKey import RSA
@@ -117,7 +117,32 @@ class RemoteCdm(Cdm):
             raise ValueError(f"Cannot Close CDM Session, {r.text} [{r.status_code}]")
         del self._sessions[session_id]
 
-    # TODO: Implement set_service_certificate with /service_cert API schema
+    def set_service_certificate(self, session_id: bytes, certificate: Optional[Union[bytes, str]]) -> str:
+        session = self._sessions.get(session_id)
+        if not session:
+            raise InvalidSession(f"Session identifier {session_id!r} is invalid.")
+
+        if certificate is None:
+            certificate_b64 = None
+        elif isinstance(certificate, str):
+            certificate_b64 = certificate  # assuming base64
+        elif isinstance(certificate, bytes):
+            certificate_b64 = base64.b64encode(certificate).decode()
+        else:
+            raise DecodeError(f"Expecting Certificate to be base64 or bytes, not {certificate!r}")
+
+        r = self.__session.post(
+            url=f"{self.host}/{self.device_name}/set_service_certificate",
+            json={
+                "session_id": session_id.hex(),
+                "certificate": certificate_b64
+            }
+        )
+        if r.status_code != 200:
+            raise ValueError(f"Cannot Set CDMs Service Certificate, {r.text} [{r.status_code}]")
+        r = r.json()["data"]
+
+        return r["provider_id"]
 
     def get_license_challenge(
         self,
@@ -149,17 +174,11 @@ class RemoteCdm(Cdm):
         except ValueError:
             raise InvalidLicenseType(f"License Type {type_!r} is invalid")
 
-        if session.service_certificate:
-            service_certificate_b64 = base64.b64encode(session.service_certificate.SerializeToString()).decode()
-        else:
-            service_certificate_b64 = None
-
         r = self.__session.post(
             url=f"{self.host}/{self.device_name}/challenge/{type_}",
             json={
                 "session_id": session_id.hex(),
-                "init_data": base64.b64encode(init_data).decode(),
-                "service_certificate": service_certificate_b64
+                "init_data": base64.b64encode(init_data).decode()
             }
         )
         if r.status_code != 200: