updater.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. import json
  2. import os
  3. import re
  4. import shutil
  5. import subprocess
  6. import sys
  7. import timeit
  8. from copy import deepcopy
  9. from typing import Literal, NotRequired, Optional, TypedDict
  10. import requests
  11. import yaml
  12. from semver import Version
  13. # Get TMP_DIR variable from environment
  14. TMP_DIR = os.path.join(os.environ.get("TMP_DIR", "/tmp"), "ohmyzsh")
  15. # Relative path to dependencies.yml file
  16. DEPS_YAML_FILE = ".github/dependencies.yml"
  17. # Dry run flag
  18. DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
  19. # utils for tag comparison
  20. BASEVERSION = re.compile(
  21. r"""[vV]?
  22. (?P<major>(0|[1-9])\d*)
  23. (\.
  24. (?P<minor>(0|[1-9])\d*)
  25. (\.
  26. (?P<patch>(0|[1-9])\d*)
  27. )?
  28. )?
  29. """,
  30. re.VERBOSE,
  31. )
  32. def coerce(version: str) -> Optional[Version]:
  33. match = BASEVERSION.search(version)
  34. if not match:
  35. return None
  36. # BASEVERSION looks for `MAJOR.minor.patch` in the string given
  37. # it fills with None if any of them is missing (for example `2.1`)
  38. ver = {
  39. key: 0 if value is None else value for key, value in match.groupdict().items()
  40. }
  41. # Version takes `major`, `minor`, `patch` arguments
  42. ver = Version(**ver) # pyright: ignore[reportArgumentType]
  43. return ver
  44. class CodeTimer:
  45. def __init__(self, name=None):
  46. self.name = " '" + name + "'" if name else ""
  47. def __enter__(self):
  48. self.start = timeit.default_timer()
  49. def __exit__(self, exc_type, exc_value, traceback):
  50. self.took = (timeit.default_timer() - self.start) * 1000.0
  51. print("Code block" + self.name + " took: " + str(self.took) + " ms")
  52. ### YAML representation
  53. def str_presenter(dumper, data):
  54. """
  55. Configures yaml for dumping multiline strings
  56. Ref: https://stackoverflow.com/a/33300001
  57. """
  58. if len(data.splitlines()) > 1: # check for multiline string
  59. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  60. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
# Register str_presenter for `str` on two representer tables: the default
# Dumper used by yaml.dump (via yaml.add_representer), and SafeRepresenter.
# The second registration is presumably what makes yaml.safe_dump (used by
# DependencyStore.write_store) emit multi-line scripts as `|` blocks.
yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
# Types
class DependencyDict(TypedDict):
    """One entry of dependencies.yml; the entry's key is its local path."""

    # NOTE: declaration order matters; Dependency.__str__ prints fields in this order.
    # GitHub `owner/name` of the upstream repository
    repo: str
    # Upstream branch compared against and cloned when copying files
    branch: str
    # Pinned version: a commit SHA, or `tag:<tag>` for tag-tracked dependencies
    version: str
    # Optional bash script run inside the fresh upstream clone before copying
    precopy: NotRequired[str]
    # Optional bash script run in the local dependency directory after copying
    postcopy: NotRequired[str]


class DependencyYAML(TypedDict):
    """Top-level shape of dependencies.yml: local path -> dependency entry."""

    dependencies: dict[str, DependencyDict]


class UpdateStatusFalse(TypedDict):
    """Update-check result when no newer upstream version exists."""

    has_updates: Literal[False]


class UpdateStatusTrue(TypedDict):
    """Update-check result describing the newer upstream version."""

    has_updates: Literal[True]
    # New tag name (tag mode) or head commit SHA (branch mode)
    version: str
    # Web URL comparing the stored version with the new one
    compare_url: str
    # SHA of the git object the new version points to
    head_ref: str
    # Web URL of the new release/commit
    head_url: str
  80. class CommandRunner:
  81. class Exception(Exception):
  82. def __init__(self, message, returncode, stage, stdout, stderr):
  83. super().__init__(message)
  84. self.returncode = returncode
  85. self.stage = stage
  86. self.stdout = stdout
  87. self.stderr = stderr
  88. @staticmethod
  89. def run_or_fail(command: list[str], stage: str, *args, **kwargs):
  90. if DRY_RUN and command[0] == "gh":
  91. command.insert(0, "echo")
  92. result = subprocess.run(command, *args, capture_output=True, **kwargs)
  93. if result.returncode != 0:
  94. raise CommandRunner.Exception(
  95. f"{stage} command failed with exit code {result.returncode}",
  96. returncode=result.returncode,
  97. stage=stage,
  98. stdout=result.stdout.decode("utf-8"),
  99. stderr=result.stderr.decode("utf-8"),
  100. )
  101. return result
  102. class DependencyStore:
  103. store: DependencyYAML = {"dependencies": {}}
  104. @staticmethod
  105. def set(data: DependencyYAML):
  106. DependencyStore.store = data
  107. @staticmethod
  108. def update_dependency_version(path: str, version: str) -> DependencyYAML:
  109. with CodeTimer(f"store deepcopy: {path}"):
  110. store_copy = deepcopy(DependencyStore.store)
  111. dependency = store_copy["dependencies"].get(path)
  112. if dependency is None:
  113. raise ValueError(f"Dependency {path} {version} not found")
  114. dependency["version"] = version
  115. store_copy["dependencies"][path] = dependency
  116. return store_copy
  117. @staticmethod
  118. def write_store(file: str, data: DependencyYAML):
  119. with open(file, "w") as yaml_file:
  120. yaml.safe_dump(data, yaml_file, sort_keys=False)
  121. class Dependency:
  122. def __init__(self, path: str, values: DependencyDict):
  123. self.path = path
  124. self.values = values
  125. self.name: str = ""
  126. self.desc: str = ""
  127. self.kind: str = ""
  128. match path.split("/"):
  129. case ["plugins", name]:
  130. self.name = name
  131. self.kind = "plugin"
  132. self.desc = f"{name} plugin"
  133. case ["themes", name]:
  134. self.name = name.replace(".zsh-theme", "")
  135. self.kind = "theme"
  136. self.desc = f"{self.name} theme"
  137. case _:
  138. self.name = self.desc = path
  139. def __str__(self):
  140. output: str = ""
  141. for key in DependencyDict.__dict__["__annotations__"].keys():
  142. if key not in self.values:
  143. output += f"{key}: None\n"
  144. continue
  145. value = self.values[key]
  146. if "\n" not in value:
  147. output += f"{key}: {value}\n"
  148. else:
  149. output += f"{key}:\n "
  150. output += value.replace("\n", "\n ", value.count("\n") - 1)
  151. return output
  152. def update_or_notify(self):
  153. # Print dependency settings
  154. print(f"Processing {self.desc}...", file=sys.stderr)
  155. print(self, file=sys.stderr)
  156. # Check for updates
  157. repo = self.values["repo"]
  158. remote_branch = self.values["branch"]
  159. version = self.values["version"]
  160. is_tag = version.startswith("tag:")
  161. try:
  162. with CodeTimer(f"update check: {repo}"):
  163. if is_tag:
  164. status = GitHub.check_newer_tag(repo, version.replace("tag:", ""))
  165. else:
  166. status = GitHub.check_updates(repo, remote_branch, version)
  167. if status["has_updates"] is True:
  168. short_sha = status["head_ref"][:8]
  169. new_version = status["version"] if is_tag else short_sha
  170. try:
  171. branch_name = f"update/{self.path}/{new_version}"
  172. # Create new branch
  173. branch = Git.checkout_or_create_branch(branch_name)
  174. # Update dependencies.yml file
  175. self.__update_yaml(
  176. f"tag:{new_version}" if is_tag else status["version"]
  177. )
  178. # Update dependency files
  179. self.__apply_upstream_changes()
  180. # Add all changes and commit
  181. Git.add_and_commit(self.name, short_sha)
  182. # Push changes to remote
  183. Git.push(branch)
  184. # Create GitHub PR
  185. GitHub.create_pr(
  186. branch,
  187. f"feat({self.name}): update to version {new_version}",
  188. f"""## Description
  189. Update for **{self.desc}**: update to version [{new_version}]({status['head_url']}).
  190. Check out the [list of changes]({status['compare_url']}).
  191. """,
  192. )
  193. # Clean up repository
  194. Git.clean_repo()
  195. except (CommandRunner.Exception, shutil.Error) as e:
  196. # Handle exception on automatic update
  197. match type(e):
  198. case CommandRunner.Exception:
  199. # Print error message
  200. print(
  201. f"Error running {e.stage} command: {e.returncode}", # pyright: ignore[reportAttributeAccessIssue]
  202. file=sys.stderr,
  203. )
  204. print(e.stderr, file=sys.stderr) # pyright: ignore[reportAttributeAccessIssue]
  205. case shutil.Error:
  206. print(f"Error copying files: {e}", file=sys.stderr)
  207. try:
  208. Git.clean_repo()
  209. except CommandRunner.Exception as e:
  210. print(
  211. f"Error reverting repository to clean state: {e}",
  212. file=sys.stderr,
  213. )
  214. sys.exit(1)
  215. # Create a GitHub issue to notify maintainer
  216. title = f"{self.path}: update to {new_version}"
  217. body = f"""## Description
  218. There is a new version of `{self.name}` {self.kind} available.
  219. New version: [{new_version}]({status['head_url']})
  220. Check out the [list of changes]({status['compare_url']}).
  221. """
  222. print("Creating GitHub issue", file=sys.stderr)
  223. print(f"{title}\n\n{body}", file=sys.stderr)
  224. GitHub.create_issue(title, body)
  225. except Exception as e:
  226. print(e, file=sys.stderr)
  227. def __update_yaml(self, new_version: str) -> None:
  228. dep_yaml = DependencyStore.update_dependency_version(self.path, new_version)
  229. DependencyStore.write_store(DEPS_YAML_FILE, dep_yaml)
  230. def __apply_upstream_changes(self) -> None:
  231. # Patterns to ignore in copying files from upstream repo
  232. GLOBAL_IGNORE = [".git", ".github", ".gitignore"]
  233. path = os.path.abspath(self.path)
  234. precopy = self.values.get("precopy")
  235. postcopy = self.values.get("postcopy")
  236. repo = self.values["repo"]
  237. branch = self.values["branch"]
  238. remote_url = f"https://github.com/{repo}.git"
  239. repo_dir = os.path.join(TMP_DIR, repo)
  240. # Clone repository
  241. Git.clone(remote_url, branch, repo_dir, reclone=True)
  242. # Run precopy on tmp repo
  243. if precopy is not None:
  244. print("Running precopy script:", end="\n ", file=sys.stderr)
  245. print(
  246. precopy.replace("\n", "\n ", precopy.count("\n") - 1), file=sys.stderr
  247. )
  248. CommandRunner.run_or_fail(
  249. ["bash", "-c", precopy], cwd=repo_dir, stage="Precopy"
  250. )
  251. # Copy files from upstream repo
  252. print(f"Copying files from {repo_dir} to {path}", file=sys.stderr)
  253. shutil.copytree(
  254. repo_dir,
  255. path,
  256. dirs_exist_ok=True,
  257. ignore=shutil.ignore_patterns(*GLOBAL_IGNORE),
  258. )
  259. # Run postcopy on our repository
  260. if postcopy is not None:
  261. print("Running postcopy script:", end="\n ", file=sys.stderr)
  262. print(
  263. postcopy.replace("\n", "\n ", postcopy.count("\n") - 1),
  264. file=sys.stderr,
  265. )
  266. CommandRunner.run_or_fail(
  267. ["bash", "-c", postcopy], cwd=path, stage="Postcopy"
  268. )
  269. class Git:
  270. default_branch = "master"
  271. @staticmethod
  272. def clone(remote_url: str, branch: str, repo_dir: str, reclone=False):
  273. # If repo needs to be fresh
  274. if reclone and os.path.exists(repo_dir):
  275. shutil.rmtree(repo_dir)
  276. # Clone repo in tmp directory and checkout branch
  277. if not os.path.exists(repo_dir):
  278. print(
  279. f"Cloning {remote_url} to {repo_dir} and checking out {branch}",
  280. file=sys.stderr,
  281. )
  282. CommandRunner.run_or_fail(
  283. ["git", "clone", "--depth=1", "-b", branch, remote_url, repo_dir],
  284. stage="Clone",
  285. )
  286. @staticmethod
  287. def checkout_or_create_branch(branch_name: str):
  288. # Get current branch name
  289. result = CommandRunner.run_or_fail(
  290. ["git", "rev-parse", "--abbrev-ref", "HEAD"], stage="GetDefaultBranch"
  291. )
  292. Git.default_branch = result.stdout.decode("utf-8").strip()
  293. # Create new branch and return created branch name
  294. try:
  295. # try to checkout already existing branch
  296. CommandRunner.run_or_fail(
  297. ["git", "checkout", branch_name], stage="CreateBranch"
  298. )
  299. except CommandRunner.Exception:
  300. # otherwise create new branch
  301. CommandRunner.run_or_fail(
  302. ["git", "checkout", "-b", branch_name], stage="CreateBranch"
  303. )
  304. return branch_name
  305. @staticmethod
  306. def add_and_commit(scope: str, version: str):
  307. user_name = os.environ.get("GIT_APP_NAME")
  308. user_email = os.environ.get("GIT_APP_EMAIL")
  309. # Add all files to git staging
  310. CommandRunner.run_or_fail(["git", "add", "-A", "-v"], stage="AddFiles")
  311. # Reset environment and git config
  312. clean_env = os.environ.copy()
  313. clean_env["LANG"] = "C.UTF-8"
  314. clean_env["GIT_CONFIG_GLOBAL"] = "/dev/null"
  315. clean_env["GIT_CONFIG_NOSYSTEM"] = "1"
  316. # check if repo is clean (clean => no error, no commit)
  317. try:
  318. CommandRunner.run_or_fail(
  319. ["git", "diff", "--exit-code"], stage="CheckRepoClean", env=clean_env
  320. )
  321. except CommandRunner.Exception:
  322. # Commit with settings above
  323. CommandRunner.run_or_fail(
  324. [
  325. "git",
  326. "-c",
  327. f"user.name={user_name}",
  328. "-c",
  329. f"user.email={user_email}",
  330. "commit",
  331. "-m",
  332. f"feat({scope}): update to {version}",
  333. ],
  334. stage="CreateCommit",
  335. env=clean_env,
  336. )
  337. @staticmethod
  338. def push(branch: str):
  339. CommandRunner.run_or_fail(
  340. ["git", "push", "-u", "origin", branch], stage="PushBranch"
  341. )
  342. @staticmethod
  343. def clean_repo():
  344. CommandRunner.run_or_fail(
  345. ["git", "reset", "--hard", "HEAD"], stage="ResetRepository"
  346. )
  347. CommandRunner.run_or_fail(
  348. ["git", "checkout", Git.default_branch], stage="CheckoutDefaultBranch"
  349. )
  350. class GitHub:
  351. @staticmethod
  352. def check_newer_tag(repo, current_tag) -> UpdateStatusFalse | UpdateStatusTrue:
  353. # GET /repos/:owner/:repo/git/refs/tags
  354. url = f"https://api.github.com/repos/{repo}/git/refs/tags"
  355. # Send a GET request to the GitHub API
  356. response = requests.get(url)
  357. current_version = coerce(current_tag)
  358. if current_version is None:
  359. raise ValueError(
  360. f"Stored {current_version} from {repo} does not follow semver"
  361. )
  362. # If the request was successful
  363. if response.status_code == 200:
  364. # Parse the JSON response
  365. data = response.json()
  366. if len(data) == 0:
  367. return {
  368. "has_updates": False,
  369. }
  370. latest_ref = None
  371. latest_version: Optional[Version] = None
  372. for ref in data:
  373. # we find the tag since GitHub returns it as plain git ref
  374. tag_version = coerce(ref["ref"].replace("refs/tags/", ""))
  375. if tag_version is None:
  376. # we skip every tag that is not semver-complaint
  377. continue
  378. if latest_version is None or tag_version.compare(latest_version) > 0:
  379. # if we have a "greater" semver version, set it as latest
  380. latest_version = tag_version
  381. latest_ref = ref
  382. # raise if no valid semver tag is found
  383. if latest_ref is None or latest_version is None:
  384. raise ValueError(f"No tags following semver found in {repo}")
  385. # we get the tag since GitHub returns it as plain git ref
  386. latest_tag = latest_ref["ref"].replace("refs/tags/", "")
  387. if latest_version.compare(current_version) <= 0:
  388. return {
  389. "has_updates": False,
  390. }
  391. return {
  392. "has_updates": True,
  393. "version": latest_tag,
  394. "compare_url": f"https://github.com/{repo}/compare/{current_tag}...{latest_tag}",
  395. "head_ref": latest_ref["object"]["sha"],
  396. "head_url": f"https://github.com/{repo}/releases/tag/{latest_tag}",
  397. }
  398. else:
  399. # If the request was not successful, raise an exception
  400. raise Exception(
  401. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  402. )
  403. @staticmethod
  404. def check_updates(repo, branch, version) -> UpdateStatusFalse | UpdateStatusTrue:
  405. url = f"https://api.github.com/repos/{repo}/compare/{version}...{branch}"
  406. # Send a GET request to the GitHub API
  407. response = requests.get(url)
  408. # If the request was successful
  409. if response.status_code == 200:
  410. # Parse the JSON response
  411. data = response.json()
  412. # If the base is behind the head, there is a newer version
  413. has_updates = data["status"] != "identical"
  414. if not has_updates:
  415. return {
  416. "has_updates": False,
  417. }
  418. return {
  419. "has_updates": data["status"] != "identical",
  420. "version": data["commits"][-1]["sha"],
  421. "compare_url": data["permalink_url"],
  422. "head_ref": data["commits"][-1]["sha"],
  423. "head_url": data["commits"][-1]["html_url"],
  424. }
  425. else:
  426. # If the request was not successful, raise an exception
  427. raise Exception(
  428. f"GitHub API request failed with status code {response.status_code}: {response.json()}"
  429. )
  430. @staticmethod
  431. def create_issue(title: str, body: str) -> None:
  432. cmd = ["gh", "issue", "create", "-t", title, "-b", body]
  433. CommandRunner.run_or_fail(cmd, stage="CreateIssue")
  434. @staticmethod
  435. def create_pr(branch: str, title: str, body: str) -> None:
  436. # first of all let's check if PR is already open
  437. check_cmd = [
  438. "gh",
  439. "pr",
  440. "list",
  441. "--state",
  442. "open",
  443. "--head",
  444. branch,
  445. "--json",
  446. "title",
  447. ]
  448. # returncode is 0 also if no PRs are found
  449. output = json.loads(
  450. CommandRunner.run_or_fail(check_cmd, stage="CheckPullRequestOpen")
  451. .stdout.decode("utf-8")
  452. .strip()
  453. )
  454. # we have PR in this case!
  455. if len(output) > 0:
  456. return
  457. cmd = [
  458. "gh",
  459. "pr",
  460. "create",
  461. "-B",
  462. Git.default_branch,
  463. "-H",
  464. branch,
  465. "-t",
  466. title,
  467. "-b",
  468. body,
  469. ]
  470. CommandRunner.run_or_fail(cmd, stage="CreatePullRequest")
  471. def main():
  472. # Load the YAML file
  473. with open(DEPS_YAML_FILE, "r") as yaml_file:
  474. data: DependencyYAML = yaml.safe_load(yaml_file)
  475. if "dependencies" not in data:
  476. raise Exception("dependencies.yml not properly formatted")
  477. # Cache YAML version
  478. DependencyStore.set(data)
  479. dependencies = data["dependencies"]
  480. for path in dependencies:
  481. dependency = Dependency(path, dependencies[path])
  482. dependency.update_or_notify()
  483. if __name__ == "__main__":
  484. main()