Coverage for src/receptiviti/request.py: 87%

131 statements, coverage.py v7.6.1, created at 2024-11-11 18:12 -0500

1"""Make requests to the API.""" 

2 

3import os 

4import re 

5import shutil 

6from glob import glob 

7from math import ceil 

8from multiprocessing import current_process 

9from tempfile import gettempdir 

10from time import perf_counter, time 

11from typing import List, Union 

12 

13import pandas 

14import pyarrow.dataset 

15 

16from receptiviti.frameworks import frameworks as get_frameworks 

17from receptiviti.manage_request import _get_writer, _manage_request 

18from receptiviti.norming import norming 

19from receptiviti.readin_env import readin_env 

20 

21CACHE = gettempdir() + "/receptiviti_cache/" 

22REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/" 

23 

24 

def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text=False,
    context="written",
    custom_context: Union[str, bool] = False,
    api_args: Union[dict, None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    clear_cache=False,
    request_cache=True,
    cores=1,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite=False,
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> pandas.DataFrame | None:
64 """ 

65 Send texts to be scored by the API. 

66 

67 Args: 

68 text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of 

69 strings containing the text itself, or the path to a file from which to read in text. 

70 If a DataFrame, `text_column` is used to extract such a vector. A string may also 

71 represent a directory in which to search for files. To best ensure paths are not 

72 treated as texts, either set `text_as_path` to `True`, or use `directory` to enter 

73 a directory path, or `files` to enter a vector of file paths. 

        output (str): Path to a file to write results to.
        ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
            containing IDs.
        text_column (str): Column name in `text` containing text.
        id_column (str): Column name in `text` containing IDs.
        files (list[str]): Vector of file paths, as alternate entry to `text`.
        directory (str): A directory path to search for files in, as alternate entry to `text`.
        file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
        encoding (str | None): Encoding of file(s) to be read in; one of the
            [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
            If this is `None` (default), encoding will be predicted for each file, but this can
            potentially fail, resulting in mis-encoded characters. For best (and fastest)
            results, specify encoding.
        return_text (bool): If `True`, will include a `text` column in the output with the
            original text.
        context (str): Name of the analysis context.
        custom_context (str | bool): Name of a custom context (as listed by
            `receptiviti.norming`), or `True` if `context` is the name of a custom context.
        api_args (dict): Additional arguments to include in the request.
        frameworks (str | list): One or more names of frameworks to request. Note that this
            changes what the API returns, so it will invalidate any cached results that were
            not made with the same set of frameworks.
        framework_prefix (bool): If `False`, will drop the framework prefix from column names.
            If one framework is selected, will default to `False`.
        bundle_size (int): Maximum number of texts per bundle.
        bundle_byte_limit (float): Maximum byte size of each bundle.
        collapse_lines (bool): If `True`, will treat files as containing single texts, and
            collapse multiple lines.
        retry_limit (int): Number of times to retry a failed request.
        clear_cache (bool): If `True`, will delete the `cache` before processing.
        request_cache (bool): If `False`, will not temporarily save raw requests for reuse
            within a day.
        cores (int): Number of CPU cores to use when processing multiple bundles.
        in_memory (bool | None): If `False`, will write bundles to disk, to be loaded when
            processed. Defaults to `False` when processing in parallel.
        verbose (bool): If `True`, will print status messages and preserve the progress bar.
        progress_bar (str | bool): If `False`, will not display a progress bar.
        overwrite (bool): If `True`, will overwrite an existing `output` file.
        text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
            Otherwise, this will be detected.
        dotenv (bool | str): Path to a .env file to read environment variables from. By default,
            will look for a file in the current directory or `~/Documents`.
            Passed to `readin_env` as `path`.
        cache (bool | str): Path to a cache directory, or `True` to use the default directory.
        cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
            but will store results in the cache (unlike `cache = False`).
        cache_format (str): File format of the cache; one of the available Arrow formats
            (`parquet` or `feather`).
        key (str): Your API key.
        secret (str): Your API secret.
        url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
        version (str): Version of the API; defaults to `v1`.
        endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
        Scores associated with each input text.

    Examples:
        ```
        # score a single text
        single = receptiviti.request("a text to score")

        # score multiple texts, and write results to a file
        multi = receptiviti.request(["first text to score", "second text"], "filename.csv")

        # score texts in separate files
        ## defaults to look for .txt files
        file_results = receptiviti.request(directory = "./path/to/txt_folder")

        ## could be .csv
        file_results = receptiviti.request(
            directory = "./path/to/csv_folder",
            text_column = "text", file_type = "csv"
        )

        # score texts in a single file
        results = receptiviti.request("./path/to/file.csv", text_column = "text")
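
        # request a single framework; with one framework selected, its prefix
        # is dropped from column names by default ("summary" is just an
        # illustrative choice here)
        summary_results = receptiviti.request("a text to score", frameworks = "summary")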

        ```

    Cache:
        If `cache` is specified, results for unique texts are saved in an Arrow database
        in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
        subsequent requests. This ensures that the exact same texts are not re-sent to the API.
        This does, however, add some processing time and disk space usage.

        If `cache` is `True`, a default directory (`receptiviti_cache`) will be
        looked for in the system's temporary directory (`tempfile.gettempdir()`).

        The primary cache is checked when each bundle is processed, and existing results are
        loaded at that time. When processing many bundles in parallel, and many results have
        been cached, this can cause the system to freeze and potentially crash.
        To avoid this, limit the number of cores, or disable parallel processing.

        The `cache_format` argument (or the `RECEPTIVITI_CACHE_FORMAT` environment variable)
        can be used to adjust the format of the cache.

        You can use the cache independently with
        `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.
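
        For example, a minimal sketch, assuming the default `parquet` cache format and that
        `RECEPTIVITI_CACHE` is set:

        ```
        import os
        import pyarrow.dataset

        # read all cached scores into a pandas DataFrame
        cached = pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE")).to_table().to_pandas()
        ```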

        You can also set the `clear_cache` argument to `True` to clear the cache before it is
        used again, which may be useful if the cache has grown large, or if you know new
        results will be returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least one uncached result.
        If, for instance, you add a framework to your account and want to reprocess a
        previously processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request.
        This is cleared after a day. You might want to set this to `False` if a new framework
        becomes available on your account and you want to reprocess a set of texts you scored
        recently.

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This
        is a temporary directory that contains a file for each unique bundle, which is read in
        as needed by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle
        represents a single request to the API, which is why they are limited to 1,000 texts
        and a total size of 10 MB (the default `bundle_byte_limit` of `75e5` bytes keeps
        bundles safely under this). When there is more than one bundle and `cores` is greater
        than 1, bundles are processed by multiple cores.
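
        As an illustration of the bundling rule only (the actual logic lives in
        `_manage_request`), texts are accumulated until either limit would be exceeded:

        ```
        def bundle_texts(texts, bundle_size=1000, bundle_byte_limit=75e5):
            # group texts into bundles that respect both limits
            bundles, current, current_bytes = [], [], 0
            for t in texts:
                size = len(t.encode("utf-8"))
                if current and (
                    len(current) == bundle_size or current_bytes + size > bundle_byte_limit
                ):
                    bundles.append(current)
                    current, current_bytes = [], 0
                current.append(t)
                current_bytes += size
            if current:
                bundles.append(current)
            return bundles
        ```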

        If you have texts spread across multiple files, they can be most efficiently processed
        in parallel if each file contains a single text (potentially collapsed from multiple
        lines). If files contain multiple texts (i.e., `collapse_lines=False`), then texts
        need to be read in before bundling in order to ensure bundles are under the length
        limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be placed within
        an `if __name__ == "__main__":` clause.
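
        For example, a hypothetical script (`./txt_folder` is an assumed path):

        ```
        import receptiviti

        if __name__ == "__main__":
            # guarded so that spawned worker processes do not re-run the request
            results = receptiviti.request(directory="./txt_folder", cores=4)
        ```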

207 """ 

    # parallelization reruns the calling script in worker processes,
    # which should do nothing themselves
    if cores > 1 and current_process().name != "MainProcess":
        return None
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if dotenv:
        readin_env(dotenv if isinstance(dotenv, str) else ".")
        dotenv = False

    # check norming context
    if isinstance(custom_context, str):
        context = custom_context
        custom_context = True
    if context != "written":
        if verbose:
            print(f"retrieving norming contexts ({perf_counter() - start_time:.4f})")
        available_contexts: "list[str]" = norming(name_only=True, url=url, key=key, secret=secret, verbose=False)
        if ("custom/" + context if custom_context else context) not in available_contexts:
            msg = f"norming context {context} is not on record or is not completed"
            raise RuntimeError(msg)

    # check frameworks
    if frameworks and version and "2" in version:
        if not api_args:
            api_args = {}
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        api_args["frameworks"] = [f for f in frameworks if f != "summary"]
    if api_args and "frameworks" in api_args:
        arg_frameworks: "list[str]" = (
            api_args["frameworks"].split(",") if isinstance(api_args["frameworks"], str) else api_args["frameworks"]
        )
        available_frameworks = get_frameworks(url=url, key=key, secret=secret)
        for f in arg_frameworks:
            if f not in available_frameworks:
                msg = f"requested framework is not available to your account: {f}"
                raise RuntimeError(msg)
        if isinstance(api_args["frameworks"], list):
            api_args["frameworks"] = ",".join(api_args["frameworks"])

    # set up the results cache
    if isinstance(cache, str) and cache:
        if clear_cache and os.path.exists(cache):
            shutil.rmtree(cache, True)
        os.makedirs(cache, exist_ok=True)
        if not cache_format:
            cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        if cache_format not in ["parquet", "feather"]:
            msg = "`cache_format` must be `parquet` or `feather`"
            raise RuntimeError(msg)
    else:
        cache = ""

    data, res, id_specified = _manage_request(
        text=text,
        ids=ids,
        text_column=text_column,
        id_column=id_column,
        files=files,
        directory=directory,
        file_type=file_type,
        encoding=encoding,
        context=f"custom/{context}" if custom_context else context,
        api_args=api_args,
        bundle_size=bundle_size,
        bundle_byte_limit=bundle_byte_limit,
        collapse_lines=collapse_lines,
        retry_limit=retry_limit,
        request_cache=request_cache,
        cores=cores,
        in_memory=in_memory,
        verbose=verbose,
        progress_bar=progress_bar,
        make_request=make_request,
        text_as_paths=text_as_paths,
        dotenv=dotenv,
        cache=cache,
        cache_overwrite=cache_overwrite,
        cache_format=cache_format,
        key=key,
        secret=secret,
        url=url,
        version=version,
        endpoint=endpoint,
    )

    # finalize
    if res is None or not res.shape[0]:
        msg = "no results"
        raise RuntimeError(msg)
    if isinstance(cache, str):
        writer = _get_writer(cache_format)
        # build a schema matching the result columns, for rewriting cache fragments
        schema = pyarrow.schema(
            (
                col,
                (
                    pyarrow.string()
                    if res[col].dtype == "O"
                    else (
                        pyarrow.int32()
                        if col in ["summary.word_count", "summary.sentence_count"]
                        else pyarrow.float32()
                    )
                ),
            )
            for col in res.columns
            if col not in ["id", "bin", *(api_args.keys() if api_args else [])]
        )
        # consolidate each bin's fragment files
        for bin_dir in glob(cache + "/bin=*/"):
            _defragment_bin(bin_dir, cache_format, writer, schema)
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != len(data):
        # within-call duplicate texts were only sent once; copy their scores
        # back to each duplicated ID
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

345 drops = ["custom", "bin"] 

346 if not id_specified: 

347 drops.append("id") 

348 res.drop( 

349 list({*drops}.intersection(res.columns)), 

350 axis="columns", 

351 inplace=True, 

352 ) 

353 if frameworks is not None: 

354 if verbose: 

355 print(f"selecting frameworks ({perf_counter() - start_time:.4f})") 

356 if isinstance(frameworks, str): 

357 frameworks = [frameworks] 

358 if len(frameworks) == 1 and framework_prefix is None: 

359 framework_prefix = False 

360 select = [] 

361 if id_specified: 

362 select.append("id") 

363 if return_text: 

364 select.append("text") 

365 select.append("text_hash") 

366 res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)") 

367 if isinstance(framework_prefix, bool) and not framework_prefix: 

368 prefix_pattern = re.compile("^[^.]+\\.") 

369 res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns]) 

370 

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res


def _defragment_bin(bin_dir: str, write_format: str, writer, schema: pyarrow.Schema):
    """Consolidate a cache bin's fragment files into larger chunks."""
    fragments = glob(f"{bin_dir}/*.{write_format}")
    if len(fragments) > 1:
        data = pyarrow.dataset.dataset(fragments, schema, format=write_format, exclude_invalid_files=True).to_table()
        nrows = data.num_rows
        n_chunks = max(1, ceil(nrows / 1e9))
        rows_per_chunk = max(1, ceil(nrows / n_chunks))
        time_id = str(ceil(time()))
        # write the table back out in chunks of up to 1e9 rows, stepping over row offsets
        for chunk in range(0, nrows, rows_per_chunk):
            writer(data[chunk : (chunk + rows_per_chunk)], f"{bin_dir}/part-{time_id}-{chunk}.{write_format}")
        for fragment in fragments:
            os.unlink(fragment)