updater.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. import os
  2. import subprocess
  3. import sys
  4. import requests
  5. import shutil
  6. import yaml
  7. from copy import deepcopy
  8. from typing import Optional, TypedDict
  9. # Get TMP_DIR variable from environment
  10. TMP_DIR = os.path.join(os.environ.get("TMP_DIR", "/tmp"), "ohmyzsh")
  11. # Relative path to dependencies.yml file
  12. DEPS_YAML_FILE = ".github/dependencies.yml"
  13. # Dry run flag
  14. DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
  15. import timeit
  16. class CodeTimer:
  17. def __init__(self, name=None):
  18. self.name = " '" + name + "'" if name else ''
  19. def __enter__(self):
  20. self.start = timeit.default_timer()
  21. def __exit__(self, exc_type, exc_value, traceback):
  22. self.took = (timeit.default_timer() - self.start) * 1000.0
  23. print('Code block' + self.name + ' took: ' + str(self.took) + ' ms')
  24. ### YAML representation
  25. def str_presenter(dumper, data):
  26. """
  27. Configures yaml for dumping multiline strings
  28. Ref: https://stackoverflow.com/a/33300001
  29. """
  30. if len(data.splitlines()) > 1: # check for multiline string
  31. return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
  32. return dumper.represent_scalar('tag:yaml.org,2002:str', data)
  33. yaml.add_representer(str, str_presenter)
  34. yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
  35. # Types
  36. class DependencyDict(TypedDict):
  37. repo: str
  38. branch: str
  39. version: str
  40. precopy: Optional[str]
  41. postcopy: Optional[str]
  42. class DependencyYAML(TypedDict):
  43. dependencies: dict[str, DependencyDict]
  44. class UpdateStatus(TypedDict):
  45. has_updates: bool
  46. version: Optional[str]
  47. compare_url: Optional[str]
  48. head_ref: Optional[str]
  49. head_url: Optional[str]
  50. class CommandRunner:
  51. class Exception(Exception):
  52. def __init__(self, message, returncode, stage, stdout, stderr):
  53. super().__init__(message)
  54. self.returncode = returncode
  55. self.stage = stage
  56. self.stdout = stdout
  57. self.stderr = stderr
  58. @staticmethod
  59. def run_or_fail(command: list[str], stage: str, *args, **kwargs):
  60. if DRY_RUN and command[0] == "gh":
  61. command.insert(0, "echo")
  62. result = subprocess.run(command, *args, capture_output=True, **kwargs)
  63. if result.returncode != 0:
  64. raise CommandRunner.Exception(
  65. f"{stage} command failed with exit code {result.returncode}", returncode=result.returncode,
  66. stage=stage,
  67. stdout=result.stdout.decode("utf-8"),
  68. stderr=result.stderr.decode("utf-8")
  69. )
  70. return result
  71. class DependencyStore:
  72. store: DependencyYAML = {
  73. "dependencies": {}
  74. }
  75. @staticmethod
  76. def set(data: DependencyYAML):
  77. DependencyStore.store = data
  78. @staticmethod
  79. def update_dependency_version(path: str, version: str) -> DependencyYAML:
  80. with CodeTimer(f"store deepcopy: {path}"):
  81. store_copy = deepcopy(DependencyStore.store)
  82. dependency = store_copy["dependencies"].get(path, {})
  83. dependency["version"] = version
  84. store_copy["dependencies"][path] = dependency
  85. return store_copy
  86. @staticmethod
  87. def write_store(file: str, data: DependencyYAML):
  88. with open(file, "w") as yaml_file:
  89. yaml.safe_dump(data, yaml_file, sort_keys=False)
  90. class Dependency:
  91. def __init__(self, path: str, values: DependencyDict):
  92. self.path = path
  93. self.values = values
  94. self.name: str = ""
  95. self.desc: str = ""
  96. self.kind: str = ""
  97. match path.split("/"):
  98. case ["plugins", name]:
  99. self.name = name
  100. self.kind = "plugin"
  101. self.desc = f"{name} plugin"
  102. case ["themes", name]:
  103. self.name = name.replace(".zsh-theme", "")
  104. self.kind = "theme"
  105. self.desc = f"{self.name} theme"
  106. case _:
  107. self.name = self.desc = path
  108. def __str__(self):
  109. output: str = ""
  110. for key in DependencyDict.__dict__['__annotations__'].keys():
  111. if key not in self.values:
  112. output += f"{key}: None\n"
  113. continue
  114. value = self.values[key]
  115. if "\n" not in value:
  116. output += f"{key}: {value}\n"
  117. else:
  118. output += f"{key}:\n "
  119. output += value.replace("\n", "\n ", value.count("\n") - 1)
  120. return output
  121. def update_or_notify(self):
  122. # Print dependency settings
  123. print(f"Processing {self.desc}...", file=sys.stderr)
  124. print(self, file=sys.stderr)
  125. # Check for updates
  126. repo = self.values["repo"]
  127. remote_branch = self.values["branch"]
  128. version = self.values["version"]
  129. is_tag = version.startswith("tag:")
  130. try:
  131. with CodeTimer(f"update check: {repo}"):
  132. if is_tag:
  133. status = GitHub.check_newer_tag(repo, version.replace("tag:", ""))
  134. else:
  135. status = GitHub.check_updates(repo, remote_branch, version)
  136. if status["has_updates"]:
  137. short_sha = status["head_ref"][:8]
  138. new_version = status["version"] if is_tag else short_sha
  139. try:
  140. # Create new branch
  141. branch = Git.create_branch(self.path, new_version)
  142. # Update dependencies.yml file
  143. self.__update_yaml(f"tag:{new_version}" if is_tag else status["version"])
  144. # Update dependency files
  145. self.__apply_upstream_changes()
  146. # Add all changes and commit
  147. Git.add_and_commit(self.name, short_sha)
  148. # Push changes to remote
  149. Git.push(branch)
  150. # Create GitHub PR
  151. GitHub.create_pr(
  152. branch,
  153. f"feat({self.name}): update to version {new_version}",
  154. f"""## Description
  155. Update for **{self.desc}**: update to version [{new_version}]({status['head_url']}).
  156. Check out the [list of changes]({status['compare_url']}).
  157. """
  158. )
  159. # Clean up repository
  160. Git.clean_repo()
  161. except (CommandRunner.Exception, shutil.Error) as e:
  162. # Handle exception on automatic update
  163. match type(e):
  164. case CommandRunner.Exception:
  165. # Print error message
  166. print(f"Error running {e.stage} command: {e.returncode}", file=sys.stderr)
  167. print(e.stderr, file=sys.stderr)
  168. case shutil.Error:
  169. print(f"Error copying files: {e}", file=sys.stderr)
  170. try:
  171. Git.clean_repo()
  172. except CommandRunner.Exception as e:
  173. print(f"Error reverting repository to clean state: {e}", file=sys.stderr)
  174. sys.exit(1)
  175. # Create a GitHub issue to notify maintainer
  176. title = f"{self.path}: update to {new_version}"
  177. body = (
  178. f"""## Description
  179. There is a new version of `{self.name}` {self.kind} available.
  180. New version: [{new_version}]({status['head_url']})
  181. Check out the [list of changes]({status['compare_url']}).
  182. """
  183. )
  184. print(f"Creating GitHub issue", file=sys.stderr)
  185. print(f"{title}\n\n{body}", file=sys.stderr)
  186. GitHub.create_issue(title, body)
  187. except Exception as e:
  188. print(e, file=sys.stderr)
  189. def __update_yaml(self, new_version: str) -> None:
  190. dep_yaml = DependencyStore.update_dependency_version(self.path, new_version)
  191. DependencyStore.write_store(DEPS_YAML_FILE, dep_yaml)
  192. def __apply_upstream_changes(self) -> None:
  193. # Patterns to ignore in copying files from upstream repo
  194. GLOBAL_IGNORE = [
  195. ".git",
  196. ".github",
  197. ".gitignore"
  198. ]
  199. path = os.path.abspath(self.path)
  200. precopy = self.values.get("precopy")
  201. postcopy = self.values.get("postcopy")
  202. repo = self.values["repo"]
  203. branch = self.values["branch"]
  204. remote_url = f"https://github.com/{repo}.git"
  205. repo_dir = os.path.join(TMP_DIR, repo)
  206. # Clone repository
  207. Git.clone(remote_url, branch, repo_dir, reclone=True)
  208. # Run precopy on tmp repo
  209. if precopy is not None:
  210. print("Running precopy script:", end="\n ", file=sys.stderr)
  211. print(precopy.replace("\n", "\n ", precopy.count("\n") - 1), file=sys.stderr)
  212. CommandRunner.run_or_fail(["bash", "-c", precopy], cwd=repo_dir, stage="Precopy")
  213. # Copy files from upstream repo
  214. print(f"Copying files from {repo_dir} to {path}", file=sys.stderr)
  215. shutil.copytree(repo_dir, path, dirs_exist_ok=True, ignore=shutil.ignore_patterns(*GLOBAL_IGNORE))
  216. # Run postcopy on our repository
  217. if postcopy is not None:
  218. print("Running postcopy script:", end="\n ", file=sys.stderr)
  219. print(postcopy.replace("\n", "\n ", postcopy.count("\n") - 1), file=sys.stderr)
  220. CommandRunner.run_or_fail(["bash", "-c", postcopy], cwd=path, stage="Postcopy")
  221. class Git:
  222. default_branch = "master"
  223. @staticmethod
  224. def clone(remote_url: str, branch: str, repo_dir: str, reclone=False):
  225. # If repo needs to be fresh
  226. if reclone and os.path.exists(repo_dir):
  227. shutil.rmtree(repo_dir)
  228. # Clone repo in tmp directory and checkout branch
  229. if not os.path.exists(repo_dir):
  230. print(f"Cloning {remote_url} to {repo_dir} and checking out {branch}", file=sys.stderr)
  231. CommandRunner.run_or_fail(["git", "clone", "--depth=1", "-b", branch, remote_url, repo_dir], stage="Clone")
  232. @staticmethod
  233. def create_branch(path: str, version: str):
  234. # Get current branch name
  235. result = CommandRunner.run_or_fail(["git", "rev-parse", "--abbrev-ref", "HEAD"], stage="GetDefaultBranch")
  236. Git.default_branch = result.stdout.decode("utf-8").strip()
  237. # Create new branch and return created branch name
  238. branch_name = f"update/{path}/{version}"
  239. CommandRunner.run_or_fail(["git", "checkout", "-b", branch_name], stage="CreateBranch")
  240. return branch_name
  241. @staticmethod
  242. def add_and_commit(scope: str, version: str):
  243. user_name = os.environ.get("GIT_APP_NAME")
  244. user_email = os.environ.get("GIT_APP_EMAIL")
  245. # Add all files to git staging
  246. CommandRunner.run_or_fail(["git", "add", "-A", "-v"], stage="AddFiles")
  247. # Reset environment and git config
  248. clean_env = os.environ.copy()
  249. clean_env["LANG"]="C.UTF-8"
  250. clean_env["GIT_CONFIG_GLOBAL"]="/dev/null"
  251. clean_env["GIT_CONFIG_NOSYSTEM"]="1"
  252. # Commit with settings above
  253. CommandRunner.run_or_fail([
  254. "git",
  255. "-c", f"user.name={user_name}",
  256. "-c", f"user.email={user_email}",
  257. "commit",
  258. "-m", f"feat({scope}): update to {version}"
  259. ], stage="CreateCommit", env=clean_env)
  260. @staticmethod
  261. def push(branch: str):
  262. CommandRunner.run_or_fail(["git", "push", "-u", "origin", branch], stage="PushBranch")
  263. @staticmethod
  264. def clean_repo():
  265. CommandRunner.run_or_fail(["git", "reset", "--hard", "HEAD"], stage="ResetRepository")
  266. CommandRunner.run_or_fail(["git", "checkout", Git.default_branch], stage="CheckoutDefaultBranch")
  267. class GitHub:
  268. @staticmethod
  269. def check_newer_tag(repo, current_tag) -> UpdateStatus:
  270. # GET /repos/:owner/:repo/git/refs/tags
  271. url = f"https://api.github.com/repos/{repo}/git/refs/tags"
  272. # Send a GET request to the GitHub API
  273. response = requests.get(url)
  274. # If the request was successful
  275. if response.status_code == 200:
  276. # Parse the JSON response
  277. data = response.json()
  278. if len(data) == 0:
  279. return {
  280. "has_updates": False,
  281. }
  282. latest_ref = data[-1]
  283. latest_tag = latest_ref["ref"].replace("refs/tags/", "")
  284. if latest_tag == current_tag:
  285. return {
  286. "has_updates": False,
  287. }
  288. return {
  289. "has_updates": True,
  290. "version": latest_tag,
  291. "compare_url": f"https://github.com/{repo}/compare/{current_tag}...{latest_tag}",
  292. "head_ref": latest_ref["object"]["sha"],
  293. "head_url": f"https://github.com/{repo}/releases/tag/{latest_tag}",
  294. }
  295. else:
  296. # If the request was not successful, raise an exception
  297. raise Exception(f"GitHub API request failed with status code {response.status_code}: {response.json()}")
  298. @staticmethod
  299. def check_updates(repo, branch, version) -> UpdateStatus:
  300. # TODO: add support for semver updating (based on tags)
  301. # Check if upstream github repo has a new version
  302. # GitHub API URL for comparing two commits
  303. url = f"https://api.github.com/repos/{repo}/compare/{version}...{branch}"
  304. # Send a GET request to the GitHub API
  305. response = requests.get(url)
  306. # If the request was successful
  307. if response.status_code == 200:
  308. # Parse the JSON response
  309. data = response.json()
  310. # If the base is behind the head, there is a newer version
  311. has_updates = data["status"] != "identical"
  312. if not has_updates:
  313. return {
  314. "has_updates": False,
  315. }
  316. return {
  317. "has_updates": data["status"] != "identical",
  318. "version": data["commits"][-1]["sha"],
  319. "compare_url": data["permalink_url"],
  320. "head_ref": data["commits"][-1]["sha"],
  321. "head_url": data["commits"][-1]["html_url"]
  322. }
  323. else:
  324. # If the request was not successful, raise an exception
  325. raise Exception(f"GitHub API request failed with status code {response.status_code}: {response.json()}")
  326. @staticmethod
  327. def create_issue(title: str, body: str) -> None:
  328. cmd = [
  329. "gh",
  330. "issue",
  331. "create",
  332. "-t", title,
  333. "-b", body
  334. ]
  335. CommandRunner.run_or_fail(cmd, stage="CreateIssue")
  336. @staticmethod
  337. def create_pr(branch: str, title: str, body: str) -> None:
  338. cmd = [
  339. "gh",
  340. "pr",
  341. "create",
  342. "-B", Git.default_branch,
  343. "-H", branch,
  344. "-t", title,
  345. "-b", body
  346. ]
  347. CommandRunner.run_or_fail(cmd, stage="CreatePullRequest")
  348. def main():
  349. # Load the YAML file
  350. with open(DEPS_YAML_FILE, "r") as yaml_file:
  351. data: DependencyYAML = yaml.safe_load(yaml_file)
  352. if "dependencies" not in data:
  353. raise Exception(f"dependencies.yml not properly formatted")
  354. # Cache YAML version
  355. DependencyStore.set(data)
  356. dependencies = data["dependencies"]
  357. for path in dependencies:
  358. dependency = Dependency(path, dependencies[path])
  359. dependency.update_or_notify()
  360. if __name__ == "__main__":
  361. main()