Coverage for src\packagelister\packagelister.py: 99%

98 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2024-01-12 13:53 -0600

1import ast 

2import importlib.metadata 

3import sys 

4from dataclasses import dataclass 

5 

6from pathier import Pathier, Pathish 

7from printbuddies import ProgBar 

8from typing_extensions import Self 

9 

10packages_distributions = importlib.metadata.packages_distributions() 

11 

12 

13def is_builtin(package_name: str) -> bool: 

14 """Returns whether `package_name` is a standard library module or not.""" 

15 return package_name in sys.stdlib_module_names 

16 

17 

18@dataclass 

19class Package: 

20 """Dataclass representing an imported package. 

21 

22 #### Fields: 

23 * `name: str` 

24 * `distribution_name: str | None` - the name used to `pip install`, sometimes this differs from `name` 

25 * `version: str | None` 

26 * `builtin: bool` - whether this is a standard library package or not""" 

27 

28 name: str 

29 distribution_name: str | None 

30 version: str | None 

31 builtin: bool 

32 

33 def format_requirement(self, version_specifier: str): 

34 """Returns a string of the form `{self.distribution_name}{version_specifier}{self.version}`. 

35 e.g for this package: `"packagelister>=2.0.0"`""" 

36 return f"{self.distribution_name}{version_specifier}{self.version}" 

37 

38 @classmethod 

39 def from_name(cls, package_name: str) -> Self: 

40 """Returns a `Package` instance from the package name. 

41 

42 Will attempt to determine the other class fields.""" 

43 distributions = packages_distributions.get(package_name) 

44 if distributions: 

45 distribution_name = distributions[0] 

46 version = importlib.metadata.version(distribution_name) 

47 else: 

48 distribution_name = None 

49 version = None 

50 return cls(package_name, distribution_name, version, is_builtin(package_name)) 

51 

52 

class PackageList(list[Package]):
    """A `list` of `packagelister.Package` objects with a few convenience properties on top."""

    @property
    def names(self) -> list[str]:
        """The `Package.name` of every package in this list, in list order."""
        package_names = []
        for package in self:
            package_names.append(package.name)
        return package_names

    @property
    def third_party(self) -> Self:
        """A new `PackageList` holding only the third party packages of this list.

        A package is kept when it is not flagged as `builtin` and it has a `distribution_name`."""
        kept = []
        for package in self:
            if package.builtin:
                continue
            if package.distribution_name:
                kept.append(package)
        return self.__class__(kept)

    @property
    def builtin(self) -> Self:
        """A new `PackageList` holding only the standard library packages of this list."""
        kept = []
        for package in self:
            if package.builtin:
                kept.append(package)
        return self.__class__(kept)

76 

77 

@dataclass
class File:
    """Dataclass pairing a scanned file with the packages it imports.

    #### Fields:
    * `path: Pathier` - Pathier object pointing at the scanned file
    * `packages: packagelister.PackageList` - the Package objects this file imports
    """

    path: Pathier
    packages: PackageList

89 

90 

@dataclass
class Project:
    """Dataclass representing a directory whose files have been scanned for imports.

    #### Fields:
    * `files: list[packagelister.File]`"""

    files: list[File]

    @property
    def packages(self) -> PackageList:
        """A `packagelister.PackageList` of every package imported by any file, without duplicates, sorted by name."""
        unique = []
        for scanned_file in self.files:
            for candidate in scanned_file.packages:
                if candidate in unique:
                    continue
                unique.append(candidate)
        unique.sort(key=lambda pkg: pkg.name)
        return PackageList(unique)

    @property
    def requirements(self) -> PackageList:
        """A `packagelister.PackageList` of the third party packages used by this project."""
        all_packages = self.packages
        return all_packages.third_party

    def get_formatted_requirements(
        self, version_specifier: str | None = None
    ) -> list[str]:
        """Returns the project's requirements (third party packages) as a list of strings.

        With a `version_specifier` (`==`,`>=`, `<=`, etc.) each entry is built by `Package.format_requirement`.

        Without one, each entry is just the distribution name, or the package name if there is no distribution name.
        """
        requirements = self.requirements
        if version_specifier:
            return [
                requirement.format_requirement(version_specifier)
                for requirement in requirements
            ]
        return [
            requirement.distribution_name or requirement.name
            for requirement in requirements
        ]

    def get_files_by_package(self) -> dict[str, list[Pathier]]:
        """Returns a dictionary mapping each package name to the list of file paths that import that package."""
        files_by_package = {}
        for package in self.packages:
            name = package.name
            for scanned_file in self.files:
                if name not in scanned_file.packages.names:
                    continue
                # The key is only created once a file importing the package is found.
                files_by_package.setdefault(name, []).append(scanned_file.path)
        return files_by_package

141 

142 

def get_package_names_from_source(source: str) -> list[str]:
    """Scan `source` and extract the names of imported packages/modules.

    Only the top level name of each import is kept (`import a.b.c` gives `"a"`).

    Every module named in a statement like `import x, y` is included.

    Relative imports (`from . import x`, `from .sibling import y`) are skipped,
    as they can only refer to modules of the package being scanned.

    Returns a sorted list without duplicates.

    `source` is parsed with `ast.parse`, so a `SyntaxError` is raised if it isn't valid Python."""
    tree = ast.parse(source)
    packages: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            # `level` counts the leading dots of a relative import;
            # `module` is `None` for statements like `from . import x`.
            if node.level or not node.module:
                continue
            modules = [node.module]
        else:
            continue
        for module in modules:
            packages.add(module.partition(".")[0])
    return sorted(packages)

159 

160 

def scan_file(file: Pathish) -> File:
    """Scan `file` for imports and return a `packagelister.File` instance.

    `file` is read as UTF-8 text and its imports are found with `get_package_names_from_source`.

    Imported names that are equal to any part of the path of `file` are left out."""
    if not isinstance(file, Pathier):
        file = Pathier(file)
    source = file.read_text(encoding="utf-8")
    packages = get_package_names_from_source(source)
    # don't want to pick up modules in the scanned directory
    path_parts = file.parts
    used_packages = PackageList(
        [
            Package.from_name(package)
            for package in packages
            if package not in path_parts
        ]
    )
    return File(file, used_packages)

175 

176 

def scan_dir(path: Pathish, quiet: bool = False) -> Project:
    """Recursively scan `*.py` files in `path` for imports and return a `packagelister.Project` instance.

    Set `quiet` to `True` to prevent printing.

    When not quiet, the number of files is printed and a progress bar is shown while scanning."""
    if not isinstance(path, Pathier):
        path = Pathier(path)
    files = list(path.rglob("*.py"))
    if quiet:
        return Project([scan_file(file) for file in files])
    num_files = len(files)
    print(f"Scanning {num_files} files in {path}...")
    with ProgBar(num_files, width_ratio=0.3) as bar:
        # NOTE(review): relies on `bar.display` handing back `return_object` - confirm against printbuddies.
        project = Project(
            [bar.display(return_object=scan_file(file)) for file in files]
        )
    return project