1"""Make requests to the API.""" 

2 

3import os 

4import re 

5import shutil 

6from glob import glob 

7from importlib.util import find_spec 

8from math import ceil 

9from multiprocessing import current_process 

10from tempfile import gettempdir 

11from time import perf_counter, time 

12from typing import Dict, List, Literal, Union 

13 

14import pandas 

15import pyarrow.dataset 

16 

17from receptiviti.frameworks import frameworks as get_frameworks 

18from receptiviti.manage_request import _get_writer, _manage_request 

19from receptiviti.norming import norming 

20from receptiviti.readin_env import readin_env 

21 

# default locations for the score cache and the temporary request cache
CACHE = gettempdir() + "/receptiviti_cache/"
REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"


def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text: bool = False,
    context: str = "written",
    custom_context: Union[str, bool] = False,
    api_args: Union[Dict[str, Union[str, List[str]]], None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size: int = 1000,
    bundle_byte_limit: float = 75e5,
    collapse_lines: bool = False,
    retry_limit: int = 50,
    clear_cache: bool = False,
    request_cache: bool = True,
    cores: int = 1,
    collect_results: bool = True,
    in_memory: Union[bool, None] = None,
    verbose: bool = False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite: bool = False,
    make_request: bool = True,
    text_as_paths: bool = False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_degragment: bool = True,
    cache_overwrite: bool = False,
    cache_format: str = os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key: str = os.getenv("RECEPTIVITI_KEY", ""),
    secret: str = os.getenv("RECEPTIVITI_SECRET", ""),
    url: str = os.getenv("RECEPTIVITI_URL", ""),
    version: str = os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint: str = os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> Union[pandas.DataFrame, None]:
67 """ 

68 Send texts to be scored by the API. 

69 

70 Args: 

        text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of
            strings containing the text itself, or the path to a file from which to read in text.
            If a DataFrame, `text_column` is used to extract such a vector. A string may also
            represent a directory in which to search for files. To best ensure paths are not
            treated as texts, either set `text_as_paths` to `True`, or use `directory` to enter
            a directory path, or `files` to enter a vector of file paths.
        output (str): Path to a file to write results to.
        ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
            containing IDs.
        text_column (str): Column name in `text` containing text.
        id_column (str): Column name in `text` containing IDs.
        files (list[str]): Vector of file paths, as alternate entry to `text`.
        directory (str): A directory path to search for files in, as alternate entry to `text`.
        file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
        encoding (str | None): Encoding of file(s) to be read in; one of the
            [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
            If this is `None` (default), encoding will be predicted for each file, but this can
            potentially fail, resulting in mis-encoded characters. For best (and fastest) results,
            specify encoding.
        return_text (bool): If `True`, will include a `text` column in the output with the
            original text.
        context (str): Name of the analysis context.
        custom_context (str | bool): Name of a custom context (as listed by `receptiviti.norming`),
            or `True` if `context` is the name of a custom context.
        api_args (dict): Additional arguments to include in the request.
        frameworks (str | list): One or more names of frameworks to request. Note that this
            changes the results from the API, so it will invalidate any cached results
            that were not scored with the same set of frameworks.
        framework_prefix (bool): If `False`, will drop the framework prefix from column names.
            If a single framework is selected, this defaults to `False`.
        bundle_size (int): Maximum number of texts per bundle.
        bundle_byte_limit (float): Maximum byte size of each bundle.
        collapse_lines (bool): If `True`, will treat files as containing single texts, and
            collapse multiple lines.
        retry_limit (int): Number of times to retry a failed request.
        clear_cache (bool): If `True`, will delete the `cache` before processing.
        request_cache (bool): If `False`, will not temporarily save raw requests for reuse
            within a day.
        cores (int): Number of CPU cores to use when processing multiple bundles.
        collect_results (bool): If `False`, will not retain bundle results in memory for return.
        in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when
            processed. Defaults to `False` when processing in parallel.
        verbose (bool): If `True`, will print status messages and preserve the progress bar.
        progress_bar (str | bool): If `False`, will not display a progress bar.
        overwrite (bool): If `True`, will overwrite an existing `output` file.
        make_request (bool): If `False`, will prepare and bundle texts, but not send any
            unscored texts to the API.
        text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
            Otherwise, this will be detected.
        dotenv (bool | str): Path to a .env file to read environment variables from. By default,
            will look for a file in the current directory or `~/Documents`.
            Passed to `readin_env` as `path`.
        cache (bool | str): Path to a cache directory, or `True` to use the default directory.
            The cache is an Arrow dataset, and so requires the `pyarrow` package.
        cache_degragment (bool): If `False`, will not defragment the cache after writing new
            results to it.
        cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
            but will store results in the cache (unlike `cache=False`).
        cache_format (str): File format of the cache; any of the available Arrow formats.
        key (str): Your API key.
        secret (str): Your API secret.
        url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
        version (str): Version of the API; defaults to `v1`.
        endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
        Scores associated with each input text.

    Examples:
        ```python
        # score a single text
        single = receptiviti.request("a text to score")

        # score multiple texts, and write results to a file
        multi = receptiviti.request(["first text to score", "second text"], "filename.csv")

        # score texts in separate files
        ## looks for .txt files by default
        file_results = receptiviti.request(directory="./path/to/txt_folder")

        ## could be .csv
        file_results = receptiviti.request(
            directory="./path/to/csv_folder",
            text_column="text",
            file_type="csv",
        )

        # score texts in a single file
        results = receptiviti.request("./path/to/file.csv", text_column="text")
        ```

    Request Process:
        This function (along with the internal `_manage_request` function) handles texts
        and results in several steps (see the sketch after this list):

        1. Prepare bundles (split `text` into <= `bundle_size` and <= `bundle_byte_limit`
            bundles).
            1. If `text` points to a directory or list of files, these will be read in later.
            2. If `in_memory` is `False`, bundles are written to a temporary location,
                and read back in when the request is made.
        2. Get scores for texts within each bundle.
            1. If texts are paths, or `in_memory` is `False`, will load texts.
            2. If `cache` is set, will skip any texts with cached scores.
            3. If `request_cache` is `True`, will check for a cached request.
            4. If any texts need scoring and `make_request` is `True`, will send unscored
                texts to the API.
        3. If a request was made and `request_cache` is set, will cache the response.
        4. If `cache` is set, will write bundle scores to the cache.
        5. After requests are made, if `cache` is set, will defragment the cache
            (combine bundle results within partitions).
        6. If `collect_results` is `True`, will prepare results:
            1. Will realign results with `text` (and `ids` if provided).
            2. If `output` is specified, will write realigned results to it.
            3. Will drop additional columns (such as `custom` and `id` if not provided).
            4. If `frameworks` is specified, will use it to select columns of the results.
            5. Returns results.
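
        For example, a call that touches most of these steps might look like the following
        (a sketch; the directory and cache paths are hypothetical):

        ```python
        results = receptiviti.request(
            directory="./path/to/txt_folder",  # step 1: read in and bundle texts
            bundle_size=500,  # at most 500 texts per bundle/request
            cache="./receptiviti_cache",  # steps 2.2 and 4: reuse and store scores
            cores=4,  # process bundles in parallel
        )
        ```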

    Cache:
        If `cache` is specified, results for unique texts are saved in an Arrow database
        in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
        subsequent requests. This ensures that the exact same texts are not re-sent to the API.
        This does, however, add some processing time and disc space usage.

        If `cache` is `True`, a default directory (`receptiviti_cache`) will be
        looked for in the system's temporary directory (`tempfile.gettempdir()`).
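
        For example (a sketch; scores would be cached under the system's temporary directory):

        ```python
        results = receptiviti.request("a text to score", cache=True)
        ```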

        The primary cache is checked when each bundle is processed, and existing results are
        loaded at that time. When processing many bundles in parallel, and many results have
        been cached, this can cause the system to freeze and potentially crash.
        To avoid this, limit the number of cores, or disable parallel processing.

        The `cache_format` argument (or the `RECEPTIVITI_CACHE_FORMAT` environment variable)
        can be used to adjust the format of the cache.

        You can use the cache independently with
        `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.
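
        For example, to load cached scores into a `pandas.DataFrame` (a sketch, assuming
        the `RECEPTIVITI_CACHE` environment variable points to an existing cache):

        ```python
        import os

        import pyarrow.dataset

        cached = pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))
        scores = cached.to_table().to_pandas()
        ```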

        You can also set the `clear_cache` argument to `True` to clear the cache before it is
        used again, which may be useful if the cache has gotten big, or you know new results
        will be returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least 1 uncached result.
        If, for instance, you add a framework to your account and want to reprocess a
        previously processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request.
        This is cleared after a day. You might want to set this to `False` if a new framework
        becomes available on your account and you want to rescore a set of texts you processed
        within the last day.
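
        For example (a sketch; this re-sends bundles even if the same request was made
        within the last day):

        ```python
        results = receptiviti.request("a text to score", request_cache=False)
        ```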

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This
        is a temporary directory that contains a file for each unique bundle, which is read in
        as needed by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle
        represents a single request to the API, which is why they are limited to 1000 texts
        and a total size of 10 MB. When there is more than one bundle and `cores` is greater
        than 1, bundles are processed by multiple cores.
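
        As a rough sketch of the resulting bundle count (ignoring the byte limit;
        `texts` here is a hypothetical list of strings):

        ```python
        from math import ceil

        n_bundles = ceil(len(texts) / bundle_size)
        ```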

        If you have texts spread across multiple files, they can be most efficiently processed
        in parallel if each file contains a single text (potentially collapsed from multiple
        lines). If files contain multiple texts (i.e., `collapse_lines=False`), then texts need
        to be read in before bundling in order to ensure bundles are under the length limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be protected by
        a check that `__name__` equals `"__main__"`
        (placed within an `if __name__ == "__main__":` clause), as in the sketch below.
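
        For example (a minimal sketch of such a script; the directory path is hypothetical):

        ```python
        import receptiviti

        if __name__ == "__main__":
            results = receptiviti.request(directory="./path/to/txt_folder", cores=4)
        ```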

237 """ 

    if cores > 1 and current_process().name != "MainProcess":
        return None
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if dotenv:
        readin_env(dotenv if isinstance(dotenv, str) else ".")
        dotenv = False

    # check norming context
    if isinstance(custom_context, str):
        context = custom_context
        custom_context = True
    if context != "written":
        if verbose:
            print(f"retrieving norming contexts ({perf_counter() - start_time:.4f})")
        available_contexts = norming(name_only=True, url=url, key=key, secret=secret, verbose=False)
        if (
            not isinstance(available_contexts, list)
            or ("custom/" + context if custom_context else context) not in available_contexts
        ):
            msg = f"norming context {context} is not on record or is not completed"
            raise RuntimeError(msg)

    # check frameworks
    if frameworks and version and "2" in version:
        if not api_args:
            api_args = {}
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        api_args["frameworks"] = [f for f in frameworks if f != "summary"]
    if api_args and "frameworks" in api_args:
        arg_frameworks: List[str] = (
            api_args["frameworks"].split(",") if isinstance(api_args["frameworks"], str) else api_args["frameworks"]
        )
        available_frameworks = get_frameworks(url=url, key=key, secret=secret)
        for f in arg_frameworks:
            if f not in available_frameworks:
                msg = f"requested framework is not available to your account: {f}"
                raise RuntimeError(msg)
        if isinstance(api_args["frameworks"], list):
            api_args["frameworks"] = ",".join(api_args["frameworks"])

    # validate cache settings and prepare the cache directory
    if isinstance(cache, str) and cache:
        if find_spec("pyarrow") is None:
            msg = "install the `pyarrow` package to use the cache"
            raise RuntimeError(msg)
        if clear_cache and os.path.exists(cache):
            shutil.rmtree(cache, True)
        os.makedirs(cache, exist_ok=True)
        if not cache_format:
            cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        if cache_format not in ["parquet", "feather"]:
            msg = "`cache_format` must be `parquet` or `feather`"
            raise RuntimeError(msg)
    else:
        cache = ""

    # bundle texts and collect scores for each bundle
    data, res, id_specified = _manage_request(
        text=text,
        ids=ids,
        text_column=text_column,
        id_column=id_column,
        files=files,
        directory=directory,
        file_type=file_type,
        encoding=encoding,
        context=f"custom/{context}" if custom_context else context,
        api_args=api_args,
        bundle_size=bundle_size,
        bundle_byte_limit=bundle_byte_limit,
        collapse_lines=collapse_lines,
        retry_limit=retry_limit,
        request_cache=request_cache,
        cores=cores,
        collect_results=collect_results,
        in_memory=in_memory,
        verbose=verbose,
        progress_bar=progress_bar,
        make_request=make_request,
        text_as_paths=text_as_paths,
        dotenv=dotenv,
        cache=cache,
        cache_overwrite=cache_overwrite,
        cache_format=cache_format,
        key=key,
        secret=secret,
        url=url,
        version=version,
        endpoint=endpoint,
    )

    # finalize
    if collect_results and (res is None or not res.shape[0]):
        msg = "no results"
        raise RuntimeError(msg)
    if cache and cache_degragment:
        # combine bundle results within each cache partition
        writer = _get_writer(cache_format)
        for bin_dir in glob(cache + "/bin=*/"):
            _defragment_bin(bin_dir, "feather" if cache_format == "feather" else "parquet", writer)
    if not collect_results or res is None:
        if verbose:
            print(f"done ({perf_counter() - start_time:.4f})")
        return None
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    # duplicated texts were scored only once, so copy scores to each duplicate ID
    if len(res) != len(data):
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    # drop bookkeeping columns (and `id` if it was not user-specified)
    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
    if isinstance(framework_prefix, bool) and not framework_prefix:
        prefix_pattern = re.compile("^[^.]+\\.")
        res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res


def _defragment_bin(bin_dir: str, write_format: Literal["parquet", "feather"], writer):
    # merge fragment files (one per bundle) within a cache partition into larger files
    fragments = glob(f"{bin_dir}/*.{write_format}")
    if len(fragments) > 1:
        data = pyarrow.dataset.dataset(fragments, format=write_format, exclude_invalid_files=True).to_table()
        nrows = data.num_rows
        n_chunks = max(1, ceil(nrows / 1e9))
        rows_per_chunk = max(1, ceil(nrows / n_chunks))
        time_id = str(ceil(time()))
        # `chunk` is a row offset, so step through all rows rather than chunk indices
        for chunk in range(0, nrows, rows_per_chunk):
            writer(data[chunk : (chunk + rows_per_chunk)], f"{bin_dir}/part-{time_id}-{chunk}.{write_format}")
        for fragment in fragments:
            os.unlink(fragment)