Coverage for src/receptiviti/request.py: 86%

138 statements  

coverage.py v7.6.1, created at 2024-12-01 10:33 -0500

"""Make requests to the API."""

import os
import re
import shutil
from glob import glob
from importlib.util import find_spec
from math import ceil
from multiprocessing import current_process
from tempfile import gettempdir
from time import perf_counter, time
from typing import List, Union

import pandas
import pyarrow.dataset

from receptiviti.frameworks import frameworks as get_frameworks
from receptiviti.manage_request import _get_writer, _manage_request
from receptiviti.norming import norming
from receptiviti.readin_env import readin_env

CACHE = gettempdir() + "/receptiviti_cache/"
REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"


def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text=False,
    context="written",
    custom_context: Union[str, bool] = False,
    api_args: Union[dict, None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    clear_cache=False,
    request_cache=True,
    cores=1,
    collect_results=True,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite=False,
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> pandas.DataFrame | None:
    """
    Send texts to be scored by the API.

    Args:
        text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of
            strings containing the text itself, or the path to a file from which to read in text.
            If a DataFrame, `text_column` is used to extract such a vector. A string may also
            represent a directory in which to search for files. To best ensure paths are not
            treated as texts, either set `text_as_paths` to `True`, or use `directory` to enter
            a directory path, or `files` to enter a vector of file paths.
        output (str): Path to a file to write results to.
        ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
            containing IDs.
        text_column (str): Column name in `text` containing text.
        id_column (str): Column name in `text` containing IDs.
        files (list[str]): Vector of file paths, as alternate entry to `text`.
        directory (str): A directory path to search for files in, as alternate entry to `text`.
        file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
        encoding (str | None): Encoding of file(s) to be read in; one of the
            [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
            If this is `None` (default), encoding will be predicted for each file, but this can
            potentially fail, resulting in mis-encoded characters. For best (and fastest) results,
            specify encoding.
        return_text (bool): If `True`, will include a `text` column in the output with the
            original text.
        context (str): Name of the analysis context.
        custom_context (str | bool): Name of a custom context (as listed by `receptiviti.norming`),
            or `True` if `context` is the name of a custom context.
        api_args (dict): Additional arguments to include in the request.
        frameworks (str | list): One or more names of frameworks to request. Note that this
            changes the results from the API, so it will invalidate any cached results
            without the same set of frameworks.
        framework_prefix (bool): If `False`, will drop framework prefix from column names.
            If one framework is selected, will default to `False`.
        bundle_size (int): Maximum number of texts per bundle.
        bundle_byte_limit (float): Maximum byte size of each bundle.
        collapse_lines (bool): If `True`, will treat files as containing single texts, and
            collapse multiple lines.
        retry_limit (int): Number of times to retry a failed request.
        clear_cache (bool): If `True`, will delete the `cache` before processing.
        request_cache (bool): If `False`, will not temporarily save raw requests for reuse
            within a day.
        cores (int): Number of CPU cores to use when processing multiple bundles.
        collect_results (bool): If `False`, will not retain bundle results in memory for return.
        in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when
            processed. Defaults to `False` when processing in parallel.
        verbose (bool): If `True`, will print status messages and preserve the progress bar.
        progress_bar (str | bool): If `False`, will not display a progress bar.
        overwrite (bool): If `True`, will overwrite an existing `output` file.
        text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
            Otherwise, this will be detected.
        dotenv (bool | str): Path to a .env file to read environment variables from. By default,
            will look for a file in the current directory or `~/Documents`.
            Passed to `readin_env` as `path`.
        cache (bool | str): Path to a cache directory, or `True` to use the default directory.
            The cache is an Arrow dataset, and so requires the `pyarrow` package.
        cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
            but will store results in the cache (unlike `cache = False`).
        cache_format (str): File format of the cache; one of the available Arrow formats
            (`parquet` or `feather`).
        key (str): Your API key.
        secret (str): Your API secret.
        url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
        version (str): Version of the API; defaults to `v1`.
        endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
        Scores associated with each input text.

    Examples:
        ```
        # score a single text
        single = receptiviti.request("a text to score")

        # score multiple texts, and write results to a file
        multi = receptiviti.request(["first text to score", "second text"], "filename.csv")

        # score texts in separate files
        ## defaults to look for .txt files
        file_results = receptiviti.request(directory = "./path/to/txt_folder")

        ## could be .csv
        file_results = receptiviti.request(
            directory = "./path/to/csv_folder",
            text_column = "text", file_type = "csv"
        )

        # score texts in a single file
        results = receptiviti.request("./path/to/file.csv", text_column = "text")
        ```
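
        A subset of frameworks can also be requested; the framework name here is only
        illustrative, since availability depends on your account:
        ```
        # keep only the summary framework's columns in the results
        summary_only = receptiviti.request("a text to score", frameworks="summary")
        ```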

    Cache:
        If `cache` is specified, results for unique texts are saved in an Arrow database
        in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
        subsequent requests. This ensures that the exact same texts are not re-sent to the API.
        This does, however, add some processing time and disc space usage.

        If `cache` is `True`, a default directory (`receptiviti_cache`) will be
        looked for in the system's temporary directory (`tempfile.gettempdir()`).

        The primary cache is checked when each bundle is processed, and existing results are
        loaded at that time. When processing many bundles in parallel, and many results have
        been cached, this can cause the system to freeze and potentially crash.
        To avoid this, limit the number of cores, or disable parallel processing.

        The `cache_format` argument (or the `RECEPTIVITI_CACHE_FORMAT` environment variable) can
        be used to adjust the format of the cache.

        You can use the cache independently with
        `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.
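
        For example, cached results could be loaded into pandas directly (a minimal sketch,
        assuming the default `parquet` cache format):
        ```
        import os
        import pyarrow.dataset

        # read all cached results into a DataFrame
        cached = pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE")).to_table().to_pandas()
        ```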

        You can also set the `clear_cache` argument to `True` to clear the cache before it is used
        again, which may be useful if the cache has gotten big, or you know new results will be
        returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least 1 uncached result. If,
        for instance, you add a framework to your account and want to reprocess a previously
        processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request. This
        is cleared after a day. You might want to set this to `False` if a new framework becomes
        available on your account and you want to reprocess a set of texts you processed recently.

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This is a
        temporary directory that contains a file for each unique bundle, which is read in as needed
        by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle represents
        a single request to the API, which is why they are limited to 1000 texts and a total size
        of 10 MB. When there is more than one bundle and `cores` is greater than 1, bundles are
        processed by multiple cores.
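
        For example, if individual texts are long, a smaller `bundle_size` can help keep each
        request under the byte limit (the variable and value here are only illustrative):
        ```
        # `long_texts` stands in for a list of long strings
        results = receptiviti.request(long_texts, bundle_size=100)
        ```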

        If you have texts spread across multiple files, they can be most efficiently processed in
        parallel if each file contains a single text (potentially collapsed from multiple lines).
        If files contain multiple texts (i.e., `collapse_lines=False`), then texts need to be
        read in before bundling in order to ensure bundles are under the length limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be protected by
        a check that `__name__` equals `"__main__"`
        (placed within an `if __name__ == "__main__":` clause).
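
        For example, a script that scores a directory of files on several cores might be
        structured like this (the path, output file, and core count are only illustrative):
        ```
        # score_texts.py
        import receptiviti

        if __name__ == "__main__":
            receptiviti.request(directory="./path/to/txt_folder", output="results.csv", cores=4)
        ```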

    """
    if cores > 1 and current_process().name != "MainProcess":
        return None
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if dotenv:
        readin_env(dotenv if isinstance(dotenv, str) else ".")
        dotenv = False

    # check norming context
    if isinstance(custom_context, str):
        context = custom_context
        custom_context = True
    if context != "written":
        if verbose:
            print(f"retrieving norming contexts ({perf_counter() - start_time:.4f})")
        available_contexts: List[str] = norming(name_only=True, url=url, key=key, secret=secret, verbose=False)
        if ("custom/" + context if custom_context else context) not in available_contexts:
            msg = f"norming context {context} is not on record or is not completed"
            raise RuntimeError(msg)

    # check frameworks
    if frameworks and version and "2" in version:
        if not api_args:
            api_args = {}
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        api_args["frameworks"] = [f for f in frameworks if f != "summary"]
    if api_args and "frameworks" in api_args:
        arg_frameworks: List[str] = (
            api_args["frameworks"].split(",") if isinstance(api_args["frameworks"], str) else api_args["frameworks"]
        )
        available_frameworks = get_frameworks(url=url, key=key, secret=secret)
        for f in arg_frameworks:
            if f not in available_frameworks:
                msg = f"requested framework is not available to your account: {f}"
                raise RuntimeError(msg)
        if isinstance(api_args["frameworks"], list):
            api_args["frameworks"] = ",".join(api_args["frameworks"])

    if isinstance(cache, str) and cache:
        if find_spec("pyarrow") is None:
            msg = "install the `pyarrow` package to use the cache"
            raise RuntimeError(msg)
        if clear_cache and os.path.exists(cache):
            shutil.rmtree(cache, True)
        os.makedirs(cache, exist_ok=True)
        if not cache_format:
            cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        if cache_format not in ["parquet", "feather"]:
            msg = "`cache_format` must be `parquet` or `feather`"
            raise RuntimeError(msg)
    else:
        cache = ""

    data, res, id_specified = _manage_request(
        text=text,
        ids=ids,
        text_column=text_column,
        id_column=id_column,
        files=files,
        directory=directory,
        file_type=file_type,
        encoding=encoding,
        context=f"custom/{context}" if custom_context else context,
        api_args=api_args,
        bundle_size=bundle_size,
        bundle_byte_limit=bundle_byte_limit,
        collapse_lines=collapse_lines,
        retry_limit=retry_limit,
        request_cache=request_cache,
        cores=cores,
        collect_results=collect_results,
        in_memory=in_memory,
        verbose=verbose,
        progress_bar=progress_bar,
        make_request=make_request,
        text_as_paths=text_as_paths,
        dotenv=dotenv,
        cache=cache,
        cache_overwrite=cache_overwrite,
        cache_format=cache_format,
        key=key,
        secret=secret,
        url=url,
        version=version,
        endpoint=endpoint,
    )

    # finalize
    if collect_results and (res is None or not res.shape[0]):
        msg = "no results"
        raise RuntimeError(msg)
    if cache:
        writer = _get_writer(cache_format)
        for bin_dir in glob(cache + "/bin=*/"):
            _defragment_bin(bin_dir, cache_format, writer)
    if not collect_results:
        if verbose:
            print(f"done ({perf_counter() - start_time:.4f})")
        return None
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != len(data):
        # duplicate texts are only sent once, so copy their scores back to the IDs of duplicates
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
        if isinstance(framework_prefix, bool) and not framework_prefix:
            prefix_pattern = re.compile("^[^.]+\\.")
            res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res


def _defragment_bin(bin_dir: str, write_format: str, writer):
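    """Merge a cache bin's fragment files into consolidated part files, removing the originals."""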

    fragments = glob(f"{bin_dir}/*.{write_format}")
    if len(fragments) > 1:
        data = pyarrow.dataset.dataset(fragments, format=write_format, exclude_invalid_files=True).to_table()
        nrows = data.num_rows
        n_chunks = max(1, ceil(nrows / 1e9))
        rows_per_chunk = max(1, ceil(nrows / n_chunks))
        time_id = str(ceil(time()))
        # step through row offsets so every chunk of the combined table is written out
        for chunk in range(0, nrows, rows_per_chunk):
            writer(data[chunk : (chunk + rows_per_chunk)], f"{bin_dir}/part-{time_id}-{chunk}.{write_format}")
        for fragment in fragments:
            os.unlink(fragment)