Coverage for src\packagelister\packagelister.py: 99%
98 statements
« prev ^ index » next coverage.py v7.2.2, created at 2024-01-12 13:53 -0600
« prev ^ index » next coverage.py v7.2.2, created at 2024-01-12 13:53 -0600
1import ast
2import importlib.metadata
3import sys
4from dataclasses import dataclass
6from pathier import Pathier, Pathish
7from printbuddies import ProgBar
8from typing_extensions import Self
# Computed once, when this module is imported.
# `Package.from_name` looks an imported package's name up in this mapping to get
# the list of distributions recorded for it (it uses the first entry).
packages_distributions = importlib.metadata.packages_distributions()
13def is_builtin(package_name: str) -> bool:
14 """Returns whether `package_name` is a standard library module or not."""
15 return package_name in sys.stdlib_module_names
18@dataclass
19class Package:
20 """Dataclass representing an imported package.
22 #### Fields:
23 * `name: str`
24 * `distribution_name: str | None` - the name used to `pip install`, sometimes this differs from `name`
25 * `version: str | None`
26 * `builtin: bool` - whether this is a standard library package or not"""
28 name: str
29 distribution_name: str | None
30 version: str | None
31 builtin: bool
33 def format_requirement(self, version_specifier: str):
34 """Returns a string of the form `{self.distribution_name}{version_specifier}{self.version}`.
35 e.g for this package: `"packagelister>=2.0.0"`"""
36 return f"{self.distribution_name}{version_specifier}{self.version}"
38 @classmethod
39 def from_name(cls, package_name: str) -> Self:
40 """Returns a `Package` instance from the package name.
42 Will attempt to determine the other class fields."""
43 distributions = packages_distributions.get(package_name)
44 if distributions:
45 distribution_name = distributions[0]
46 version = importlib.metadata.version(distribution_name)
47 else:
48 distribution_name = None
49 version = None
50 return cls(package_name, distribution_name, version, is_builtin(package_name))
class PackageList(list[Package]):
    """A `list` subclass with convenience accessors for collections of `packagelister.Package` objects."""

    @property
    def names(self) -> list[str]:
        """The `Package.name` of every entry, in list order."""
        return [entry.name for entry in self]

    @property
    def third_party(self) -> Self:
        """A new `PackageList` holding only the third party packages of this list.

        An entry qualifies when it is not builtin and has a truthy `distribution_name`."""
        selected = []
        for entry in self:
            if entry.builtin or not entry.distribution_name:
                continue
            selected.append(entry)
        return self.__class__(selected)

    @property
    def builtin(self) -> Self:
        """A new `PackageList` holding only the standard library packages of this list."""
        selected = []
        for entry in self:
            if entry.builtin:
                selected.append(entry)
        return self.__class__(selected)
@dataclass
class File:
    """Pairs a scanned file with the packages it imports.

    #### Fields:
    * `path: Pathier` - location of the scanned file
    * `packages: packagelister.PackageList` - `Package` objects for the imports found in the file
    """

    # Where the scanned file lives.
    path: Pathier
    # The packages this file imports.
    packages: PackageList
@dataclass
class Project:
    """The result of scanning a directory's files for imports.

    #### Fields:
    * `files: list[packagelister.File]`"""

    files: list[File]

    @property
    def packages(self) -> PackageList:
        """Every distinct package imported across `files`, as a `packagelister.PackageList` sorted by name."""
        unique = []
        for scanned in self.files:
            for candidate in scanned.packages:
                if candidate in unique:
                    continue
                unique.append(candidate)
        unique.sort(key=lambda pkg: pkg.name)
        return PackageList(unique)

    @property
    def requirements(self) -> PackageList:
        """The third party subset of `packages`, as a `packagelister.PackageList`."""
        all_packages = self.packages
        return all_packages.third_party

    def get_formatted_requirements(
        self, version_specifier: str | None = None
    ) -> list[str]:
        """Format each requirement (third party package) as a string.

        With a `version_specifier` (`==`, `>=`, `<=`, etc.) each entry comes from `Package.format_requirement`.
        Without one, each entry is the distribution name, falling back to the package name.
        """
        requirements = self.requirements
        if not version_specifier:
            return [req.distribution_name or req.name for req in requirements]
        return [req.format_requirement(version_specifier) for req in requirements]

    def get_files_by_package(self) -> dict[str, list[Pathier]]:
        """Map each package name to the paths of the files that import that package."""
        grouped = {}
        for package in self.packages:
            package_name = package.name
            for scanned in self.files:
                if package_name not in scanned.packages.names:
                    continue
                grouped.setdefault(package_name, []).append(scanned.path)
        return grouped
def get_package_names_from_source(source: str) -> list[str]:
    """Scan `source` and extract the names of imported packages/modules.

    Every `import` and `from ... import` statement in `source` is inspected, wherever it appears
    (module level, inside functions, inside conditionals, ...).
    Only the top level name is kept, e.g. `import os.path` yields `"os"`.

    Every module named in a multi-name statement such as `import os, sys` is reported.
    Relative imports (`from . import x`, `from .sibling import y`) are skipped,
    since they refer to the importing package rather than to an installable one.

    Returns a sorted list of names with no duplicates.
    Raises `SyntaxError` if `source` cannot be parsed."""
    packages: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            # One statement can name several modules: `import a, b.c as d`.
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            # `level` is non-zero for relative imports; `module` is `None` for `from . import x`.
            if node.level or not node.module:
                continue
            modules = [node.module]
        else:
            continue
        for module in modules:
            # Keep only the part before the first dot.
            packages.add(module.partition(".")[0])
    return sorted(packages)
def scan_file(file: Pathish) -> File:
    """Read `file`, collect the packages it imports, and wrap the result in a `packagelister.File`.

    The file is read as UTF-8 text.
    Imported names that also appear among the components of the file's own path are left out."""
    # Exact type check: anything that is not precisely a `Pathier` gets wrapped in one.
    if type(file) is not Pathier:
        file = Pathier(file)
    source = file.read_text(encoding="utf-8")
    imported = []
    for name in get_package_names_from_source(source):
        # don't want to pick up modules in the scanned directory
        if name in file.parts:
            continue
        imported.append(Package.from_name(name))
    return File(file, PackageList(imported))
def scan_dir(path: Pathish, quiet: bool = False) -> Project:
    """Recursively scan `*.py` files in `path` for imports and return a `packagelister.Project` instance.

    Set `quiet` to `True` to prevent printing.
    When `quiet` is `False`, a status line is printed and a progress bar is shown while scanning."""
    # Exact type check: anything that is not precisely a `Pathier` gets wrapped in one.
    path = Pathier(path) if not type(path) == Pathier else path
    files = list(path.rglob("*.py"))
    if quiet:
        project = Project([scan_file(file) for file in files])
    else:
        num_files = len(files)
        print(f"Scanning {num_files} files in {path}...")
        with ProgBar(num_files, width_ratio=0.3) as bar:
            # `bar.display` is handed each scan result via `return_object`,
            # and the value it returns is what gets collected.
            project = Project(
                [bar.display(return_object=scan_file(file)) for file in files]
            )
    return project