updater.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. import os
  2. import re
  3. import shutil
  4. import subprocess
  5. import sys
  6. import timeit
  7. from copy import deepcopy
  8. from typing import Literal, NotRequired, Optional, TypedDict
  9. import requests
  10. import yaml
  11. from semver import Version
  12. # Get TMP_DIR variable from environment
  13. TMP_DIR = os.path.join(os.environ.get("TMP_DIR", "/tmp"), "ohmyzsh")
  14. # Relative path to dependencies.yml file
  15. DEPS_YAML_FILE = ".github/dependencies.yml"
  16. # Dry run flag
  17. DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
  18. # utils for tag comparison
  19. BASEVERSION = re.compile(
  20. r"""[vV]?
  21. (?P<major>(0|[1-9])\d*)
  22. (\.
  23. (?P<minor>(0|[1-9])\d*)
  24. (\.
  25. (?P<patch>(0|[1-9])\d*)
  26. )?
  27. )?
  28. """,
  29. re.VERBOSE,
  30. )
  31. def coerce(version: str) -> Optional[Version]:
  32. match = BASEVERSION.search(version)
  33. if not match:
  34. return None
  35. # BASEVERSION looks for `MAJOR.minor.patch` in the string given
  36. # it fills with None if any of them is missing (for example `2.1`)
  37. ver = {
  38. key: 0 if value is None else value for key, value in match.groupdict().items()
  39. }
  40. # Version takes `major`, `minor`, `patch` arguments
  41. ver = Version(**ver) # pyright: ignore[reportArgumentType]
  42. return ver
  43. class CodeTimer:
  44. def __init__(self, name=None):
  45. self.name = " '" + name + "'" if name else ""
  46. def __enter__(self):
  47. self.start = timeit.default_timer()
  48. def __exit__(self, exc_type, exc_value, traceback):
  49. self.took = (timeit.default_timer() - self.start) * 1000.0
  50. print("Code block" + self.name + " took: " + str(self.took) + " ms")
  51. ### YAML representation
  52. def str_presenter(dumper, data):
  53. """
  54. Configures yaml for dumping multiline strings
  55. Ref: https://stackoverflow.com/a/33300001
  56. """
  57. if len(data.splitlines()) > 1: # check for multiline string
  58. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  59. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  60. yaml.add_representer(str, str_presenter)
  61. yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
# Types
class DependencyDict(TypedDict):
    """Settings of one entry under ``dependencies`` in dependencies.yml.

    The declared field order is also the order in which Dependency.__str__
    prints the settings, since it iterates over these annotations.
    """

    # GitHub repository as "owner/name" (cloned from https://github.com/<repo>.git)
    repo: str
    # Upstream branch that is cloned, and compared against for commit-pinned entries
    branch: str
    # Pinned upstream version: "tag:<tag>" for tag-tracked entries, otherwise a commit SHA
    version: str
    # Optional shell script, run with `bash -c` inside the fresh upstream clone before copying
    precopy: NotRequired[str]
    # Optional shell script, run with `bash -c` inside the local dependency directory after copying
    postcopy: NotRequired[str]
  69. class DependencyYAML(TypedDict):
  70. dependencies: dict[str, DependencyDict]
  71. class UpdateStatusFalse(TypedDict):
  72. has_updates: Literal[False]
  73. class UpdateStatusTrue(TypedDict):
  74. has_updates: Literal[True]
  75. version: str
  76. compare_url: str
  77. head_ref: str
  78. head_url: str
  79. class CommandRunner:
  80. class Exception(Exception):
  81. def __init__(self, message, returncode, stage, stdout, stderr):
  82. super().__init__(message)
  83. self.returncode = returncode
  84. self.stage = stage
  85. self.stdout = stdout
  86. self.stderr = stderr
  87. @staticmethod
  88. def run_or_fail(command: list[str], stage: str, *args, **kwargs):
  89. if DRY_RUN and command[0] == "gh":
  90. command.insert(0, "echo")
  91. result = subprocess.run(command, *args, capture_output=True, **kwargs)
  92. if result.returncode != 0:
  93. raise CommandRunner.Exception(
  94. f"{stage} command failed with exit code {result.returncode}",
  95. returncode=result.returncode,
  96. stage=stage,
  97. stdout=result.stdout.decode("utf-8"),
  98. stderr=result.stderr.decode("utf-8"),
  99. )
  100. return result
  101. class DependencyStore:
  102. store: DependencyYAML = {"dependencies": {}}
  103. @staticmethod
  104. def set(data: DependencyYAML):
  105. DependencyStore.store = data
  106. @staticmethod
  107. def update_dependency_version(path: str, version: str) -> DependencyYAML:
  108. with CodeTimer(f"store deepcopy: {path}"):
  109. store_copy = deepcopy(DependencyStore.store)
  110. dependency = store_copy["dependencies"].get(path)
  111. if dependency is None:
  112. raise ValueError(f"Dependency {path} {version} not found")
  113. dependency["version"] = version
  114. store_copy["dependencies"][path] = dependency
  115. return store_copy
  116. @staticmethod
  117. def write_store(file: str, data: DependencyYAML):
  118. with open(file, "w") as yaml_file:
  119. yaml.safe_dump(data, yaml_file, sort_keys=False)
  120. class Dependency:
  121. def __init__(self, path: str, values: DependencyDict):
  122. self.path = path
  123. self.values = values
  124. self.name: str = ""
  125. self.desc: str = ""
  126. self.kind: str = ""
  127. match path.split("/"):
  128. case ["plugins", name]:
  129. self.name = name
  130. self.kind = "plugin"
  131. self.desc = f"{name} plugin"
  132. case ["themes", name]:
  133. self.name = name.replace(".zsh-theme", "")
  134. self.kind = "theme"
  135. self.desc = f"{self.name} theme"
  136. case _:
  137. self.name = self.desc = path
  138. def __str__(self):
  139. output: str = ""
  140. for key in DependencyDict.__dict__["__annotations__"].keys():
  141. if key not in self.values:
  142. output += f"{key}: None\n"
  143. continue
  144. value = self.values[key]
  145. if "\n" not in value:
  146. output += f"{key}: {value}\n"
  147. else:
  148. output += f"{key}:\n "
  149. output += value.replace("\n", "\n ", value.count("\n") - 1)
  150. return output
  151. def update_or_notify(self):
  152. # Print dependency settings
  153. print(f"Processing {self.desc}...", file=sys.stderr)
  154. print(self, file=sys.stderr)
  155. # Check for updates
  156. repo = self.values["repo"]
  157. remote_branch = self.values["branch"]
  158. version = self.values["version"]
  159. is_tag = version.startswith("tag:")
  160. try:
  161. with CodeTimer(f"update check: {repo}"):
  162. if is_tag:
  163. status = GitHub.check_newer_tag(repo, version.replace("tag:", ""))
  164. else:
  165. status = GitHub.check_updates(repo, remote_branch, version)
  166. if status["has_updates"] is True:
  167. short_sha = status["head_ref"][:8]
  168. new_version = status["version"] if is_tag else short_sha
  169. try:
  170. # Create new branch
  171. branch = Git.create_branch(self.path, new_version)
  172. # Update dependencies.yml file
  173. self.__update_yaml(
  174. f"tag:{new_version}" if is_tag else status["version"]
  175. )
  176. # Update dependency files
  177. self.__apply_upstream_changes()
  178. # Add all changes and commit
  179. Git.add_and_commit(self.name, short_sha)
  180. # Push changes to remote
  181. Git.push(branch)
  182. # Create GitHub PR
  183. GitHub.create_pr(
  184. branch,
  185. f"feat({self.name}): update to version {new_version}",
  186. f"""## Description
  187. Update for **{self.desc}**: update to version [{new_version}]({status['head_url']}).
  188. Check out the [list of changes]({status['compare_url']}).
  189. """,
  190. )
  191. # Clean up repository
  192. Git.clean_repo()
  193. except (CommandRunner.Exception, shutil.Error) as e:
  194. # Handle exception on automatic update
  195. match type(e):
  196. case CommandRunner.Exception:
  197. # Print error message
  198. print(
  199. f"Error running {e.stage} command: {e.returncode}", # pyright: ignore[reportAttributeAccessIssue]
  200. file=sys.stderr,
  201. )
  202. print(e.stderr, file=sys.stderr) # pyright: ignore[reportAttributeAccessIssue]
  203. case shutil.Error:
  204. print(f"Error copying files: {e}", file=sys.stderr)
  205. try:
  206. Git.clean_repo()
  207. except CommandRunner.Exception as e:
  208. print(
  209. f"Error reverting repository to clean state: {e}",
  210. file=sys.stderr,
  211. )
  212. sys.exit(1)
  213. # Create a GitHub issue to notify maintainer
  214. title = f"{self.path}: update to {new_version}"
  215. body = f"""## Description
  216. There is a new version of `{self.name}` {self.kind} available.
  217. New version: [{new_version}]({status['head_url']})
  218. Check out the [list of changes]({status['compare_url']}).
  219. """
  220. print("Creating GitHub issue", file=sys.stderr)
  221. print(f"{title}\n\n{body}", file=sys.stderr)
  222. GitHub.create_issue(title, body)
  223. except Exception as e:
  224. print(e, file=sys.stderr)
  225. def __update_yaml(self, new_version: str) -> None:
  226. dep_yaml = DependencyStore.update_dependency_version(self.path, new_version)
  227. DependencyStore.write_store(DEPS_YAML_FILE, dep_yaml)
  228. def __apply_upstream_changes(self) -> None:
  229. # Patterns to ignore in copying files from upstream repo
  230. GLOBAL_IGNORE = [".git", ".github", ".gitignore"]
  231. path = os.path.abspath(self.path)
  232. precopy = self.values.get("precopy")
  233. postcopy = self.values.get("postcopy")
  234. repo = self.values["repo"]
  235. branch = self.values["branch"]
  236. remote_url = f"https://github.com/{repo}.git"
  237. repo_dir = os.path.join(TMP_DIR, repo)
  238. # Clone repository
  239. Git.clone(remote_url, branch, repo_dir, reclone=True)
  240. # Run precopy on tmp repo
  241. if precopy is not None:
  242. print("Running precopy script:", end="\n ", file=sys.stderr)
  243. print(
  244. precopy.replace("\n", "\n ", precopy.count("\n") - 1), file=sys.stderr
  245. )
  246. CommandRunner.run_or_fail(
  247. ["bash", "-c", precopy], cwd=repo_dir, stage="Precopy"
  248. )
  249. # Copy files from upstream repo
  250. print(f"Copying files from {repo_dir} to {path}", file=sys.stderr)
  251. shutil.copytree(
  252. repo_dir,
  253. path,
  254. dirs_exist_ok=True,
  255. ignore=shutil.ignore_patterns(*GLOBAL_IGNORE),
  256. )
  257. # Run postcopy on our repository
  258. if postcopy is not None:
  259. print("Running postcopy script:", end="\n ", file=sys.stderr)
  260. print(
  261. postcopy.replace("\n", "\n ", postcopy.count("\n") - 1),
  262. file=sys.stderr,
  263. )
  264. CommandRunner.run_or_fail(
  265. ["bash", "-c", postcopy], cwd=path, stage="Postcopy"
  266. )
  267. class Git:
  268. default_branch = "master"
  269. @staticmethod
  270. def clone(remote_url: str, branch: str, repo_dir: str, reclone=False):
  271. # If repo needs to be fresh
  272. if reclone and os.path.exists(repo_dir):
  273. shutil.rmtree(repo_dir)
  274. # Clone repo in tmp directory and checkout branch
  275. if not os.path.exists(repo_dir):
  276. print(
  277. f"Cloning {remote_url} to {repo_dir} and checking out {branch}",
  278. file=sys.stderr,
  279. )
  280. CommandRunner.run_or_fail(
  281. ["git", "clone", "--depth=1", "-b", branch, remote_url, repo_dir],
  282. stage="Clone",
  283. )
  284. @staticmethod
  285. def create_branch(path: str, version: str):
  286. # Get current branch name
  287. result = CommandRunner.run_or_fail(
  288. ["git", "rev-parse", "--abbrev-ref", "HEAD"], stage="GetDefaultBranch"
  289. )
  290. Git.default_branch = result.stdout.decode("utf-8").strip()
  291. # Create new branch and return created branch name
  292. branch_name = f"update/{path}/{version}"
  293. CommandRunner.run_or_fail(
  294. ["git", "checkout", "-b", branch_name], stage="CreateBranch"
  295. )
  296. return branch_name
  297. @staticmethod
  298. def add_and_commit(scope: str, version: str):
  299. user_name = os.environ.get("GIT_APP_NAME")
  300. user_email = os.environ.get("GIT_APP_EMAIL")
  301. # Add all files to git staging
  302. CommandRunner.run_or_fail(["git", "add", "-A", "-v"], stage="AddFiles")
  303. # Reset environment and git config
  304. clean_env = os.environ.copy()
  305. clean_env["LANG"] = "C.UTF-8"
  306. clean_env["GIT_CONFIG_GLOBAL"] = "/dev/null"
  307. clean_env["GIT_CONFIG_NOSYSTEM"] = "1"
  308. # Commit with settings above
  309. CommandRunner.run_or_fail(
  310. [
  311. "git",
  312. "-c",
  313. f"user.name={user_name}",
  314. "-c",
  315. f"user.email={user_email}",
  316. "commit",
  317. "-m",
  318. f"feat({scope}): update to {version}",
  319. ],
  320. stage="CreateCommit",
  321. env=clean_env,
  322. )
  323. @staticmethod
  324. def push(branch: str):
  325. CommandRunner.run_or_fail(
  326. ["git", "push", "-u", "origin", branch], stage="PushBranch"
  327. )
  328. @staticmethod
  329. def clean_repo():
  330. CommandRunner.run_or_fail(
  331. ["git", "reset", "--hard", "HEAD"], stage="ResetRepository"
  332. )
  333. CommandRunner.run_or_fail(
  334. ["git", "checkout", Git.default_branch], stage="CheckoutDefaultBranch"
  335. )
  336. class GitHub:
  337. @staticmethod
  338. def check_newer_tag(repo, current_tag) -> UpdateStatusFalse | UpdateStatusTrue:
  339. # GET /repos/:owner/:repo/git/refs/tags
  340. url = f"https://api.github.com/repos/{repo}/git/refs/tags"
  341. # Send a GET request to the GitHub API
  342. response = requests.get(url)
  343. current_version = coerce(current_tag)
  344. if current_version is None:
  345. raise ValueError(
  346. f"Stored {current_version} from {repo} does not follow semver"
  347. )
  348. # If the request was successful
  349. if response.status_code == 200:
  350. # Parse the JSON response
  351. data = response.json()
  352. if len(data) == 0:
  353. return {
  354. "has_updates": False,
  355. }
  356. latest_ref = None
  357. latest_version: Optional[Version] = None
  358. for ref in data:
  359. # we find the tag since GitHub returns it as plain git ref
  360. tag_version = coerce(ref["ref"].replace("refs/tags/", ""))
  361. if tag_version is None:
  362. # we skip every tag that is not semver-complaint
  363. continue
  364. if latest_version is None or tag_version.compare(latest_version) > 0:
  365. # if we have a "greater" semver version, set it as latest
  366. latest_version = tag_version
  367. latest_ref = ref
  368. # raise if no valid semver tag is found
  369. if latest_ref is None or latest_version is None:
  370. raise ValueError(f"No tags following semver found in {repo}")
  371. # we get the tag since GitHub returns it as plain git ref
  372. latest_tag = latest_ref["ref"].replace("refs/tags/", "")
  373. if latest_version.compare(current_version) <= 0:
  374. return {
  375. "has_updates": False,
  376. }
  377. return {
  378. "has_updates": True,
  379. "version": latest_tag,
  380. "compare_url": f"https://github.com/{repo}/compare/{current_tag}...{latest_tag}",
  381. "head_ref": latest_ref["object"]["sha"],
  382. "head_url": f"https://github.com/{repo}/releases/tag/{latest_tag}",
  383. }
  384. else:
  385. # If the request was not successful, raise an exception
  386. raise Exception(
  387. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  388. )
  389. @staticmethod
  390. def check_updates(repo, branch, version) -> UpdateStatusFalse | UpdateStatusTrue:
  391. url = f"https://api.github.com/repos/{repo}/compare/{version}...{branch}"
  392. # Send a GET request to the GitHub API
  393. response = requests.get(url)
  394. # If the request was successful
  395. if response.status_code == 200:
  396. # Parse the JSON response
  397. data = response.json()
  398. # If the base is behind the head, there is a newer version
  399. has_updates = data["status"] != "identical"
  400. if not has_updates:
  401. return {
  402. "has_updates": False,
  403. }
  404. return {
  405. "has_updates": data["status"] != "identical",
  406. "version": data["commits"][-1]["sha"],
  407. "compare_url": data["permalink_url"],
  408. "head_ref": data["commits"][-1]["sha"],
  409. "head_url": data["commits"][-1]["html_url"],
  410. }
  411. else:
  412. # If the request was not successful, raise an exception
  413. raise Exception(
  414. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  415. )
  416. @staticmethod
  417. def create_issue(title: str, body: str) -> None:
  418. cmd = ["gh", "issue", "create", "-t", title, "-b", body]
  419. CommandRunner.run_or_fail(cmd, stage="CreateIssue")
  420. @staticmethod
  421. def create_pr(branch: str, title: str, body: str) -> None:
  422. cmd = [
  423. "gh",
  424. "pr",
  425. "create",
  426. "-B",
  427. Git.default_branch,
  428. "-H",
  429. branch,
  430. "-t",
  431. title,
  432. "-b",
  433. body,
  434. ]
  435. CommandRunner.run_or_fail(cmd, stage="CreatePullRequest")
  436. def main():
  437. # Load the YAML file
  438. with open(DEPS_YAML_FILE, "r") as yaml_file:
  439. data: DependencyYAML = yaml.safe_load(yaml_file)
  440. if "dependencies" not in data:
  441. raise Exception("dependencies.yml not properly formatted")
  442. # Cache YAML version
  443. DependencyStore.set(data)
  444. dependencies = data["dependencies"]
  445. for path in dependencies:
  446. dependency = Dependency(path, dependencies[path])
  447. dependency.update_or_notify()
  448. if __name__ == "__main__":
  449. main()