updater.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. import json
  2. import os
  3. import re
  4. import shutil
  5. import subprocess
  6. import sys
  7. import timeit
  8. from copy import deepcopy
  9. from typing import Literal, NotRequired, Optional, TypedDict
  10. import requests
  11. import yaml
  12. from semver import Version
  13. # Get TMP_DIR variable from environment
  14. TMP_DIR = os.path.join(os.environ.get("TMP_DIR", "/tmp"), "ohmyzsh")
  15. # Relative path to dependencies.yml file
  16. DEPS_YAML_FILE = ".github/dependencies.yml"
  17. # Dry run flag
  18. DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
  19. # utils for tag comparison
  20. BASEVERSION = re.compile(
  21. r"""[vV]?
  22. (?P<major>(0|[1-9])\d*)
  23. (\.
  24. (?P<minor>(0|[1-9])\d*)
  25. (\.
  26. (?P<patch>(0|[1-9])\d*)
  27. )?
  28. )?
  29. """,
  30. re.VERBOSE,
  31. )
  32. def coerce(version: str) -> Optional[Version]:
  33. match = BASEVERSION.search(version)
  34. if not match:
  35. return None
  36. # BASEVERSION looks for `MAJOR.minor.patch` in the string given
  37. # it fills with None if any of them is missing (for example `2.1`)
  38. ver = {
  39. key: 0 if value is None else value for key, value in match.groupdict().items()
  40. }
  41. # Version takes `major`, `minor`, `patch` arguments
  42. ver = Version(**ver) # pyright: ignore[reportArgumentType]
  43. return ver
  44. class CodeTimer:
  45. def __init__(self, name=None):
  46. self.name = " '" + name + "'" if name else ""
  47. def __enter__(self):
  48. self.start = timeit.default_timer()
  49. def __exit__(self, exc_type, exc_value, traceback):
  50. self.took = (timeit.default_timer() - self.start) * 1000.0
  51. print("Code block" + self.name + " took: " + str(self.took) + " ms")
  52. ### YAML representation
  53. def str_presenter(dumper, data):
  54. """
  55. Configures yaml for dumping multiline strings
  56. Ref: https://stackoverflow.com/a/33300001
  57. """
  58. if len(data.splitlines()) > 1: # check for multiline string
  59. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  60. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  61. yaml.add_representer(str, str_presenter)
  62. yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
  63. # Types
  64. class DependencyDict(TypedDict):
  65. repo: str
  66. branch: str
  67. version: str
  68. precopy: NotRequired[str]
  69. postcopy: NotRequired[str]
  70. class DependencyYAML(TypedDict):
  71. dependencies: dict[str, DependencyDict]
  72. class UpdateStatusFalse(TypedDict):
  73. has_updates: Literal[False]
  74. class UpdateStatusTrue(TypedDict):
  75. has_updates: Literal[True]
  76. version: str
  77. compare_url: str
  78. head_ref: str
  79. head_url: str
  80. class CommandRunner:
  81. class Exception(Exception):
  82. def __init__(self, message, returncode, stage, stdout, stderr):
  83. super().__init__(message)
  84. self.returncode = returncode
  85. self.stage = stage
  86. self.stdout = stdout
  87. self.stderr = stderr
  88. @staticmethod
  89. def run_or_fail(command: list[str], stage: str, *args, **kwargs):
  90. if DRY_RUN and command[0] == "gh":
  91. command.insert(0, "echo")
  92. result = subprocess.run(command, *args, capture_output=True, **kwargs)
  93. if result.returncode != 0:
  94. raise CommandRunner.Exception(
  95. f"{stage} command failed with exit code {result.returncode}",
  96. returncode=result.returncode,
  97. stage=stage,
  98. stdout=result.stdout.decode("utf-8"),
  99. stderr=result.stderr.decode("utf-8"),
  100. )
  101. return result
  102. class DependencyStore:
  103. store: DependencyYAML = {"dependencies": {}}
  104. @staticmethod
  105. def set(data: DependencyYAML):
  106. DependencyStore.store = data
  107. @staticmethod
  108. def update_dependency_version(path: str, version: str) -> DependencyYAML:
  109. with CodeTimer(f"store deepcopy: {path}"):
  110. store_copy = deepcopy(DependencyStore.store)
  111. dependency = store_copy["dependencies"].get(path)
  112. if dependency is None:
  113. raise ValueError(f"Dependency {path} {version} not found")
  114. dependency["version"] = version
  115. store_copy["dependencies"][path] = dependency
  116. return store_copy
  117. @staticmethod
  118. def write_store(file: str, data: DependencyYAML):
  119. with open(file, "w") as yaml_file:
  120. yaml.safe_dump(data, yaml_file, sort_keys=False)
  121. class Dependency:
  122. def __init__(self, path: str, values: DependencyDict):
  123. self.path = path
  124. self.values = values
  125. self.name: str = ""
  126. self.desc: str = ""
  127. self.kind: str = ""
  128. match path.split("/"):
  129. case ["plugins", name]:
  130. self.name = name
  131. self.kind = "plugin"
  132. self.desc = f"{name} plugin"
  133. case ["themes", name]:
  134. self.name = name.replace(".zsh-theme", "")
  135. self.kind = "theme"
  136. self.desc = f"{self.name} theme"
  137. case _:
  138. self.name = self.desc = path
  139. def __str__(self):
  140. output: str = ""
  141. for key in DependencyDict.__dict__["__annotations__"].keys():
  142. if key not in self.values:
  143. output += f"{key}: None\n"
  144. continue
  145. value = self.values[key]
  146. if "\n" not in value:
  147. output += f"{key}: {value}\n"
  148. else:
  149. output += f"{key}:\n "
  150. output += value.replace("\n", "\n ", value.count("\n") - 1)
  151. return output
  152. def update_or_notify(self):
  153. # Print dependency settings
  154. print(f"Processing {self.desc}...", file=sys.stderr)
  155. print(self, file=sys.stderr)
  156. # Check for updates
  157. repo = self.values["repo"]
  158. remote_branch = self.values["branch"]
  159. version = self.values["version"]
  160. is_tag = version.startswith("tag:")
  161. try:
  162. with CodeTimer(f"update check: {repo}"):
  163. if is_tag:
  164. status = GitHub.check_newer_tag(repo, version.replace("tag:", ""))
  165. else:
  166. status = GitHub.check_updates(repo, remote_branch, version)
  167. if status["has_updates"] is True:
  168. short_sha = status["head_ref"][:8]
  169. new_version = status["version"] if is_tag else short_sha
  170. try:
  171. branch_name = f"update/{self.path}/{new_version}"
  172. # Create new branch
  173. branch = Git.checkout_or_create_branch(branch_name)
  174. # Update dependencies.yml file
  175. self.__update_yaml(
  176. f"tag:{new_version}" if is_tag else status["version"]
  177. )
  178. # Update dependency files
  179. self.__apply_upstream_changes()
  180. # Add all changes and commit
  181. has_new_commit = Git.add_and_commit(self.name, short_sha)
  182. if has_new_commit:
  183. # Push changes to remote
  184. Git.push(branch)
  185. # Create GitHub PR
  186. GitHub.create_pr(
  187. branch,
  188. f"feat({self.name}): update to version {new_version}",
  189. f"""## Description
  190. Update for **{self.desc}**: update to version [{new_version}]({status['head_url']}).
  191. Check out the [list of changes]({status['compare_url']}).
  192. """,
  193. )
  194. # Clean up repository
  195. Git.clean_repo()
  196. except (CommandRunner.Exception, shutil.Error) as e:
  197. # Handle exception on automatic update
  198. match type(e):
  199. case CommandRunner.Exception:
  200. # Print error message
  201. print(
  202. f"Error running {e.stage} command: {e.returncode}", # pyright: ignore[reportAttributeAccessIssue]
  203. file=sys.stderr,
  204. )
  205. print(e.stderr, file=sys.stderr) # pyright: ignore[reportAttributeAccessIssue]
  206. case shutil.Error:
  207. print(f"Error copying files: {e}", file=sys.stderr)
  208. try:
  209. Git.clean_repo()
  210. except CommandRunner.Exception as e:
  211. print(
  212. f"Error reverting repository to clean state: {e}",
  213. file=sys.stderr,
  214. )
  215. sys.exit(1)
  216. # Create a GitHub issue to notify maintainer
  217. title = f"{self.path}: update to {new_version}"
  218. body = f"""## Description
  219. There is a new version of `{self.name}` {self.kind} available.
  220. New version: [{new_version}]({status['head_url']})
  221. Check out the [list of changes]({status['compare_url']}).
  222. """
  223. print("Creating GitHub issue", file=sys.stderr)
  224. print(f"{title}\n\n{body}", file=sys.stderr)
  225. GitHub.create_issue(title, body)
  226. except Exception as e:
  227. print(e, file=sys.stderr)
  228. def __update_yaml(self, new_version: str) -> None:
  229. dep_yaml = DependencyStore.update_dependency_version(self.path, new_version)
  230. DependencyStore.write_store(DEPS_YAML_FILE, dep_yaml)
  231. def __apply_upstream_changes(self) -> None:
  232. # Patterns to ignore in copying files from upstream repo
  233. GLOBAL_IGNORE = [".git", ".github", ".gitignore"]
  234. path = os.path.abspath(self.path)
  235. precopy = self.values.get("precopy")
  236. postcopy = self.values.get("postcopy")
  237. repo = self.values["repo"]
  238. branch = self.values["branch"]
  239. remote_url = f"https://github.com/{repo}.git"
  240. repo_dir = os.path.join(TMP_DIR, repo)
  241. # Clone repository
  242. Git.clone(remote_url, branch, repo_dir, reclone=True)
  243. # Run precopy on tmp repo
  244. if precopy is not None:
  245. print("Running precopy script:", end="\n ", file=sys.stderr)
  246. print(
  247. precopy.replace("\n", "\n ", precopy.count("\n") - 1), file=sys.stderr
  248. )
  249. CommandRunner.run_or_fail(
  250. ["bash", "-c", precopy], cwd=repo_dir, stage="Precopy"
  251. )
  252. # Copy files from upstream repo
  253. print(f"Copying files from {repo_dir} to {path}", file=sys.stderr)
  254. shutil.copytree(
  255. repo_dir,
  256. path,
  257. dirs_exist_ok=True,
  258. ignore=shutil.ignore_patterns(*GLOBAL_IGNORE),
  259. )
  260. # Run postcopy on our repository
  261. if postcopy is not None:
  262. print("Running postcopy script:", end="\n ", file=sys.stderr)
  263. print(
  264. postcopy.replace("\n", "\n ", postcopy.count("\n") - 1),
  265. file=sys.stderr,
  266. )
  267. CommandRunner.run_or_fail(
  268. ["bash", "-c", postcopy], cwd=path, stage="Postcopy"
  269. )
  270. class Git:
  271. default_branch = "master"
  272. @staticmethod
  273. def clone(remote_url: str, branch: str, repo_dir: str, reclone=False):
  274. # If repo needs to be fresh
  275. if reclone and os.path.exists(repo_dir):
  276. shutil.rmtree(repo_dir)
  277. # Clone repo in tmp directory and checkout branch
  278. if not os.path.exists(repo_dir):
  279. print(
  280. f"Cloning {remote_url} to {repo_dir} and checking out {branch}",
  281. file=sys.stderr,
  282. )
  283. CommandRunner.run_or_fail(
  284. ["git", "clone", "--depth=1", "-b", branch, remote_url, repo_dir],
  285. stage="Clone",
  286. )
  287. @staticmethod
  288. def checkout_or_create_branch(branch_name: str):
  289. # Get current branch name
  290. result = CommandRunner.run_or_fail(
  291. ["git", "rev-parse", "--abbrev-ref", "HEAD"], stage="GetDefaultBranch"
  292. )
  293. Git.default_branch = result.stdout.decode("utf-8").strip()
  294. # Create new branch and return created branch name
  295. try:
  296. # try to checkout already existing branch
  297. CommandRunner.run_or_fail(
  298. ["git", "checkout", branch_name], stage="CreateBranch"
  299. )
  300. except CommandRunner.Exception:
  301. # otherwise create new branch
  302. CommandRunner.run_or_fail(
  303. ["git", "checkout", "-b", branch_name], stage="CreateBranch"
  304. )
  305. return branch_name
  306. @staticmethod
  307. def add_and_commit(scope: str, version: str) -> bool:
  308. """
  309. Returns `True` if there were changes and were indeed commited.
  310. Returns `False` if the repo was clean and no changes were commited.
  311. """
  312. # check if repo is clean (clean => no error, no commit)
  313. try:
  314. CommandRunner.run_or_fail(
  315. ["git", "diff", "--exit-code"], stage="CheckRepoClean"
  316. )
  317. return False
  318. except CommandRunner.Exception:
  319. # if it's other kind of error just throw!
  320. pass
  321. user_name = os.environ.get("GIT_APP_NAME")
  322. user_email = os.environ.get("GIT_APP_EMAIL")
  323. # Add all files to git staging
  324. CommandRunner.run_or_fail(["git", "add", "-A", "-v"], stage="AddFiles")
  325. # Reset environment and git config
  326. clean_env = os.environ.copy()
  327. clean_env["LANG"] = "C.UTF-8"
  328. clean_env["GIT_CONFIG_GLOBAL"] = "/dev/null"
  329. clean_env["GIT_CONFIG_NOSYSTEM"] = "1"
  330. # Commit with settings above
  331. CommandRunner.run_or_fail(
  332. [
  333. "git",
  334. "-c",
  335. f"user.name={user_name}",
  336. "-c",
  337. f"user.email={user_email}",
  338. "commit",
  339. "-m",
  340. f"feat({scope}): update to {version}",
  341. ],
  342. stage="CreateCommit",
  343. env=clean_env,
  344. )
  345. return True
  346. @staticmethod
  347. def push(branch: str):
  348. CommandRunner.run_or_fail(
  349. ["git", "push", "-u", "origin", branch], stage="PushBranch"
  350. )
  351. @staticmethod
  352. def clean_repo():
  353. CommandRunner.run_or_fail(
  354. ["git", "reset", "--hard", "HEAD"], stage="ResetRepository"
  355. )
  356. CommandRunner.run_or_fail(
  357. ["git", "checkout", Git.default_branch], stage="CheckoutDefaultBranch"
  358. )
  359. class GitHub:
  360. @staticmethod
  361. def check_newer_tag(repo, current_tag) -> UpdateStatusFalse | UpdateStatusTrue:
  362. # GET /repos/:owner/:repo/git/refs/tags
  363. url = f"https://api.github.com/repos/{repo}/git/refs/tags"
  364. # Send a GET request to the GitHub API
  365. response = requests.get(url)
  366. current_version = coerce(current_tag)
  367. if current_version is None:
  368. raise ValueError(
  369. f"Stored {current_version} from {repo} does not follow semver"
  370. )
  371. # If the request was successful
  372. if response.status_code == 200:
  373. # Parse the JSON response
  374. data = response.json()
  375. if len(data) == 0:
  376. return {
  377. "has_updates": False,
  378. }
  379. latest_ref = None
  380. latest_version: Optional[Version] = None
  381. for ref in data:
  382. # we find the tag since GitHub returns it as plain git ref
  383. tag_version = coerce(ref["ref"].replace("refs/tags/", ""))
  384. if tag_version is None:
  385. # we skip every tag that is not semver-complaint
  386. continue
  387. if latest_version is None or tag_version.compare(latest_version) > 0:
  388. # if we have a "greater" semver version, set it as latest
  389. latest_version = tag_version
  390. latest_ref = ref
  391. # raise if no valid semver tag is found
  392. if latest_ref is None or latest_version is None:
  393. raise ValueError(f"No tags following semver found in {repo}")
  394. # we get the tag since GitHub returns it as plain git ref
  395. latest_tag = latest_ref["ref"].replace("refs/tags/", "")
  396. if latest_version.compare(current_version) <= 0:
  397. return {
  398. "has_updates": False,
  399. }
  400. return {
  401. "has_updates": True,
  402. "version": latest_tag,
  403. "compare_url": f"https://github.com/{repo}/compare/{current_tag}...{latest_tag}",
  404. "head_ref": latest_ref["object"]["sha"],
  405. "head_url": f"https://github.com/{repo}/releases/tag/{latest_tag}",
  406. }
  407. else:
  408. # If the request was not successful, raise an exception
  409. raise Exception(
  410. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  411. )
  412. @staticmethod
  413. def check_updates(repo, branch, version) -> UpdateStatusFalse | UpdateStatusTrue:
  414. url = f"https://api.github.com/repos/{repo}/compare/{version}...{branch}"
  415. # Send a GET request to the GitHub API
  416. response = requests.get(url)
  417. # If the request was successful
  418. if response.status_code == 200:
  419. # Parse the JSON response
  420. data = response.json()
  421. # If the base is behind the head, there is a newer version
  422. has_updates = data["status"] != "identical"
  423. if not has_updates:
  424. return {
  425. "has_updates": False,
  426. }
  427. return {
  428. "has_updates": data["status"] != "identical",
  429. "version": data["commits"][-1]["sha"],
  430. "compare_url": data["permalink_url"],
  431. "head_ref": data["commits"][-1]["sha"],
  432. "head_url": data["commits"][-1]["html_url"],
  433. }
  434. else:
  435. # If the request was not successful, raise an exception
  436. raise Exception(
  437. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  438. )
  439. @staticmethod
  440. def create_issue(title: str, body: str) -> None:
  441. cmd = ["gh", "issue", "create", "-t", title, "-b", body]
  442. CommandRunner.run_or_fail(cmd, stage="CreateIssue")
  443. @staticmethod
  444. def create_pr(branch: str, title: str, body: str) -> None:
  445. # first of all let's check if PR is already open
  446. check_cmd = [
  447. "gh",
  448. "pr",
  449. "list",
  450. "--state",
  451. "open",
  452. "--head",
  453. branch,
  454. "--json",
  455. "title",
  456. ]
  457. # returncode is 0 also if no PRs are found
  458. output = json.loads(
  459. CommandRunner.run_or_fail(check_cmd, stage="CheckPullRequestOpen")
  460. .stdout.decode("utf-8")
  461. .strip()
  462. )
  463. # we have PR in this case!
  464. if len(output) > 0:
  465. return
  466. cmd = [
  467. "gh",
  468. "pr",
  469. "create",
  470. "-B",
  471. Git.default_branch,
  472. "-H",
  473. branch,
  474. "-t",
  475. title,
  476. "-b",
  477. body,
  478. ]
  479. CommandRunner.run_or_fail(cmd, stage="CreatePullRequest")
  480. def main():
  481. # Load the YAML file
  482. with open(DEPS_YAML_FILE, "r") as yaml_file:
  483. data: DependencyYAML = yaml.safe_load(yaml_file)
  484. if "dependencies" not in data:
  485. raise Exception("dependencies.yml not properly formatted")
  486. # Cache YAML version
  487. DependencyStore.set(data)
  488. dependencies = data["dependencies"]
  489. for path in dependencies:
  490. dependency = Dependency(path, dependencies[path])
  491. dependency.update_or_notify()
  492. if __name__ == "__main__":
  493. main()