浏览代码

fix(dependencies): improve typing

Carlo Sala 11 月之前
父节点
当前提交
a258eb4547
共有 1 个文件被更改,包括 21 次插入15 次删除
  1. 21 15
      .github/workflows/dependencies/updater.py

+ 21 - 15
.github/workflows/dependencies/updater.py

@@ -4,7 +4,7 @@ import subprocess
 import sys
 import timeit
 from copy import deepcopy
-from typing import Optional, TypedDict
+from typing import Literal, NotRequired, TypedDict
 
 import requests
 import yaml
@@ -49,20 +49,24 @@ class DependencyDict(TypedDict):
     repo: str
     branch: str
     version: str
-    precopy: Optional[str]
-    postcopy: Optional[str]
+    precopy: NotRequired[str]
+    postcopy: NotRequired[str]
 
 
 class DependencyYAML(TypedDict):
     dependencies: dict[str, DependencyDict]
 
 
-class UpdateStatus(TypedDict):
-    has_updates: bool
-    version: Optional[str]
-    compare_url: Optional[str]
-    head_ref: Optional[str]
-    head_url: Optional[str]
+class UpdateStatusFalse(TypedDict):
+    has_updates: Literal[False]
+
+
+class UpdateStatusTrue(TypedDict):
+    has_updates: Literal[True]
+    version: str
+    compare_url: str
+    head_ref: str
+    head_url: str
 
 
 class CommandRunner:
@@ -105,7 +109,9 @@ class DependencyStore:
         with CodeTimer(f"store deepcopy: {path}"):
             store_copy = deepcopy(DependencyStore.store)
 
-        dependency = store_copy["dependencies"].get(path, {})
+        dependency = store_copy["dependencies"].get(path)
+        if dependency is None:
+            raise ValueError(f"Dependency {path} {version} not found")
         dependency["version"] = version
         store_copy["dependencies"][path] = dependency
 
@@ -171,7 +177,7 @@ class Dependency:
                 else:
                     status = GitHub.check_updates(repo, remote_branch, version)
 
-            if status["has_updates"]:
+            if status["has_updates"] is True:
                 short_sha = status["head_ref"][:8]
                 new_version = status["version"] if is_tag else short_sha
 
@@ -212,10 +218,10 @@ Check out the [list of changes]({status['compare_url']}).
                         case CommandRunner.Exception:
                             # Print error message
                             print(
-                                f"Error running {e.stage} command: {e.returncode}",
+                                f"Error running {e.stage} command: {e.returncode}",  # pyright: ignore[reportAttributeAccessIssue]
                                 file=sys.stderr,
                             )
-                            print(e.stderr, file=sys.stderr)
+                            print(e.stderr, file=sys.stderr)  # pyright: ignore[reportAttributeAccessIssue]
                         case shutil.Error:
                             print(f"Error copying files: {e}", file=sys.stderr)
 
@@ -378,7 +384,7 @@ class Git:
 
 class GitHub:
     @staticmethod
-    def check_newer_tag(repo, current_tag) -> UpdateStatus:
+    def check_newer_tag(repo, current_tag) -> UpdateStatusFalse | UpdateStatusTrue:
         # GET /repos/:owner/:repo/git/refs/tags
         url = f"https://api.github.com/repos/{repo}/git/refs/tags"
 
@@ -417,7 +423,7 @@ class GitHub:
             )
 
     @staticmethod
-    def check_updates(repo, branch, version) -> UpdateStatus:
+    def check_updates(repo, branch, version) -> UpdateStatusFalse | UpdateStatusTrue:
         # TODO: add support for semver updating (based on tags)
         # Check if upstream github repo has a new version
         # GitHub API URL for comparing two commits