Coverage for src/receptiviti/request.py: 79%

451 statements  

coverage.py v7.3.2, created at 2024-02-15 16:41 -0700

1"""Make requests to the API.""" 

2 

3import hashlib 

4import json 

5import math 

6import os 

7import pickle 

8import re 

9import shutil 

10import sys 

11from glob import glob 

12from multiprocessing import Process, Queue, current_process 

13from tempfile import TemporaryDirectory, gettempdir 

14from time import perf_counter, sleep, time 

15from typing import List, Union 

16 

17import numpy 

18import pandas 

19import pyarrow 

20import requests 

21from chardet.universaldetector import UniversalDetector 

22from pyarrow import compute, dataset 

23from tqdm import tqdm 

24 

25from receptiviti.readin_env import readin_env 

26from receptiviti.status import status 

27 

28CACHE = gettempdir() + "/receptiviti_cache/" 

29REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/" 

30 

31 

32def request( 

33 text: Union[str, List[str], pandas.DataFrame, None] = None, 

34 output: Union[str, None] = None, 

35 ids: Union[str, List[str], List[int], None] = None, 

36 text_column: Union[str, None] = None, 

37 id_column: Union[str, None] = None, 

38 files: Union[List[str], None] = None, 

39 directory: Union[str, None] = None, 

40 file_type: str = "txt", 

41 encoding: Union[str, None] = None, 

42 return_text=False, 

43 api_args: Union[dict, None] = None, 

44 frameworks: Union[str, List[str], None] = None, 

45 framework_prefix: Union[bool, None] = None, 

46 bundle_size=1000, 

47 bundle_byte_limit=75e5, 

48 collapse_lines=False, 

49 retry_limit=50, 

50 clear_cache=False, 

51 request_cache=True, 

52 cores=1, 

53 in_memory: Union[bool, None] = None, 

54 verbose=False, 

55 progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"), 

56 overwrite=False, 

57 make_request=True, 

58 text_as_paths=False, 

59 dotenv: Union[bool, str] = True, 

60 cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""), 

61 cache_overwrite=False, 

62 cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""), 

63 key=os.getenv("RECEPTIVITI_KEY", ""), 

64 secret=os.getenv("RECEPTIVITI_SECRET", ""), 

65 url=os.getenv("RECEPTIVITI_URL", ""), 

66 version=os.getenv("RECEPTIVITI_VERSION", ""), 

67 endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""), 

68) -> pandas.DataFrame: 

69 """ 

70 Send texts to be scored by the API. 

71 

72 Args: 

73 text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of 

74 strings containing the text itself, or the path to a file from which to read in text. 

75 If a DataFrame, `text_column` is used to extract such a vector. A string may also 

76 represent a directory in which to search for files. To best ensure paths are not 

77 treated as texts, either set `text_as_path` to `True`, or use `directory` to enter 

78 a directory path, or `files` to enter a vector of file paths. 

79 output (str): Path to a file to write results to. 

80 ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text` 

81 containing IDs. 

82 text_column (str): Column name in `text` containing text. 

83 id_column (str): Column name in `text` containing IDs. 

84 files (list[str]): Vector of file paths, as alternate entry to `text`. 

85 directory (str): A directory path to search for files in, as alternate entry to `text`. 

86 file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`). 

87 encoding (str | None): Encoding of file(s) to be read in; one of the 

88 [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings). 

89 If this is `None` (default), encoding will be predicted for each file, but this can 

90 potentially fail, resulting in mis-encoded characters. For best (and fastest) results, 

91 specify encoding. 

92 return_text (bool): If `True`, will include a `text` column in the output with the 

93 original text. 

94 api_args (dict): Additional arguments to include in the request. 

95 frameworks (str | list): One or more names of frameworks to return. 

96 framework_prefix (bool): If `False`, will drop framework prefix from column names. 

97 If one framework is selected, will default to `False`. 

98 bundle_size (int): Maximum number of texts per bundle. 

99 bundle_byte_limit (float): Maximum byte size of each bundle. 

100 collapse_lines (bool): If `True`, will treat files as containing single texts, and 

101 collapse multiple lines. 

102 retry_limit (int): Number of times to retry a failed request. 

103 clear_cache (bool): If `True`, will delete the `cache` before processing. 

104 request_cache (bool): If `False`, will not temporarily save raw requests for reuse 

105 within a day. 

106 cores (int): Number of CPU cores to use when processing multiple bundles. 

107 in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when 

108 processed. Defaults to `True` when processing in parallel. 

109 verbose (bool): If `True`, will print status messages and preserve the progress bar. 

110 progress_bar (str | bool): If `False`, will not display a progress bar. 

111 overwrite (bool): If `True`, will overwrite an existing `output` file. 

112 text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths. 

113 Otherwise, this will be detected. 

114 dotenv (bool | str): Path to a .env file to read environment variables from. By default, 

115 will for a file in the current directory or `~/Documents`. 

116 Passed to `readin_env` as `path`. 

117 cache (bool | str): Path to a cache directory, or `True` to use the default directory. 

118 cache_overwrite (bool): If `True`, will not check the cache for previously cached texts, 

119 but will store results in the cache (unlike `cache = False`). 

120 cache_format (str): File format of the cache, of available Arrow formats. 

121 key (str): Your API key. 

122 secret (str): Your API secret. 

123 url (str): The URL of the API; defaults to `https://api.receptiviti.com`. 

124 version (str): Version of the API; defaults to `v1`. 

125 endpoint (str): Endpoint of the API; defaults to `framework`. 

126 

127 Returns: 

128 Scores associated with each input text. 

129 

130 Cache: 

131 If `cache` is specified, results for unique texts are saved in an Arrow database 

132 in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with 

133 subsequent requests. This ensures that the exact same texts are not re-sent to the API. 

134 This does, however, add some processing time and disc space usage. 

135 

136 If `cache` if `True`, a default directory (`receptiviti_cache`) will be 

137 looked for in the system's temporary directory (`tempfile.gettempdir()`). 

138 

139 The primary cache is checked when each bundle is processed, and existing results are 

140 loaded at that time. When processing many bundles in parallel, and many results have 

141 been cached, this can cause the system to freeze and potentially crash. 

142 To avoid this, limit the number of cores, or disable parallel processing. 

143 

144 The `cache_format` arguments (or the `RECEPTIVITI_CACHE_FORMAT` environment variable) can be 

145 used to adjust the format of the cache. 

146 

147 You can use the cache independently with 

148 `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`. 

149 
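
        For example, to load cached results into a DataFrame yourself (a minimal
        sketch; this assumes the default Parquet `cache_format` and the hive-style
        partitioning used by this module):

            import os
            from pyarrow import dataset

            cached = dataset.dataset(
                os.getenv("RECEPTIVITI_CACHE"), partitioning="hive", format="parquet"
            ).to_table().to_pandas()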

        You can also set the `clear_cache` argument to `True` to clear the cache before it is used
        again, which may be useful if the cache has gotten big, or you know new results will be
        returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least 1 uncached result. If,
        for instance, you add a framework to your account and want to reprocess a previously
        processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request. This
        is cleared after a day. You might want to set this to `False` if a new framework becomes
        available on your account and you want to reprocess a set of texts you processed recently.

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This is a
        temporary directory that contains a file for each unique bundle, which is read in as needed
        by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle represents
        a single request to the API, which is why they are limited to 1000 texts and a total size
        of 10 MB. When there is more than one bundle and `cores` is greater than 1, bundles are
        processed by multiple cores.

        If you have texts spread across multiple files, they can be most efficiently processed in
        parallel if each file contains a single text (potentially collapsed from multiple lines).
        If files contain multiple texts (i.e., `collapse_lines=False`), then texts need to be
        read in before bundling in order to ensure bundles are under the length limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be protected by
        a check that `__name__` equals `"__main__"`
        (placed within an `if __name__ == "__main__":` clause).
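
        For example, a script that processes files in parallel might guard the call
        like this (a sketch; the directory, output path, and core count are
        placeholders, and `request` is assumed to be importable from the
        `receptiviti` package):

            import receptiviti

            if __name__ == "__main__":
                results = receptiviti.request(
                    directory="texts", output="results.csv", cores=4
                )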

185 """ 

186 if cores > 1 and current_process().name != "MainProcess": 

187 return 

188 if output is not None and os.path.isfile(output) and not overwrite: 

189 msg = "`output` file already exists; use `overwrite=True` to overwrite it" 

190 raise RuntimeError(msg) 

191 start_time = perf_counter() 

192 

193 if request_cache: 

194 if verbose: 

195 print(f"preparing request cache ({perf_counter() - start_time:.4f})") 

196 _manage_request_cache() 

197 

198 # resolve credentials and check status 

199 if dotenv: 

200 readin_env("." if isinstance(dotenv, bool) else dotenv) 

201 if not url: 

202 url = os.getenv("RECEPTIVITI_URL", "https://api.receptiviti.com") 

203 url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", url) 

204 if url_parts: 

205 from_url = url_parts.groups() 

206 if not version and from_url[0] is not None: 

207 version = from_url[0] 

208 if not endpoint and from_url[1] is not None: 

209 endpoint = from_url[1] 

210 url = ("https://" if re.match("http", url, re.I) is None else "") + re.sub( 

211 "/+[Vv]\\d+(?:/.*)?$|/+$", "", url 

212 ) 

213 if not key: 

214 key = os.getenv("RECEPTIVITI_KEY", "") 

215 if not secret: 

216 secret = os.getenv("RECEPTIVITI_SECRET", "") 

217 if not version: 

218 version = os.getenv("RECEPTIVITI_VERSION", "v1") 

219 if not endpoint: 

220 endpoint_default = "framework" if version.lower() == "v1" else "taxonomies" 

221 endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default) 

222 api_status = status(url, key, secret, dotenv, verbose=False) 

223 if not api_status or api_status.status_code != 200: 

224 msg = ( 

225 f"API status failed: {api_status.status_code}: {api_status.reason}" 

226 if api_status 

227 else "URL is not reachable" 

228 ) 

229 raise RuntimeError(msg) 

230 

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                id_column = "ids"
                text_column = "text"
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif os.path.isdir(os.path.dirname(text)):
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is not None:
            if text_column in text:
                text = text[text_column].to_list()
            else:
                msg = f"`text_column` ({text_column}) is not in `text`"
                raise IndexError(msg)
        else:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
    if isinstance(text, str):
        text = [text]
    text_is_path = all(
        isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text
    )
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            if id_column is None:
                ids = text["ids"].to_list()
            elif id_column in text:
                ids = text[id_column].to_list()
            if text_column is None:
                text_column = "text"
            text = text[text_column].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    n_original = len(data)
    data_subset = data[
        ~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())
    ]
    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    bundle_size = max(1, bundle_size)
    n_bundles = math.ceil(n_texts / min(1000, bundle_size))
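    # assign each text to a bundle index (each index repeated up to `bundle_size` times)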

    groups = data_subset.groupby(
        numpy.sort(numpy.tile(numpy.arange(n_bundles) + 1, bundle_size))[:n_texts],
        group_keys=False,
    )
    bundles = []
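    # build the bundle list, splitting any group whose total byte size would exceed the limit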

    for _, group in groups:
        if sys.getsizeof(group) > bundle_byte_limit:
            start = current = end = 0
            for txt in group["text"]:
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = (
                        "one of your texts is over the bundle size"
                        f" limit ({bundle_byte_limit / 1e6} MB)"
                    )
                    raise RuntimeError(msg)
                if (current + size) > bundle_byte_limit:
                    bundles.append(group[start:end])
                    start = end = end + 1
                    current = size
                else:
                    end += 1
                    current += size
            bundles.append(group[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    if isinstance(cache, str):
        if cache:
            if clear_cache and os.path.exists(cache):
                shutil.rmtree(cache, True)
            os.makedirs(cache, exist_ok=True)
            if not cache_format:
                cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        else:
            cache = False
    opts = {
        "url": f"{url}/{version}/{endpoint}/bulk".lower(),
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        "cache": "" if cache_overwrite or isinstance(cache, bool) and not cache else cache,
        "cache_format": cache_format,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
    }
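    # `add_hash` is prepended to each text before hashing, so cache keys reflect the API
    # arguments and credentials used for the request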

381 opts["add_hash"] = hashlib.md5( 

382 json.dumps( 

383 {**opts["add"], "url": opts["url"], "key": key, "secret": secret}, 

384 separators=(",", ":"), 

385 ).encode() 

386 ).hexdigest() 

387 if isinstance(progress_bar, str): 

388 progress_bar = progress_bar == "True" 

389 use_pb = (verbose and progress_bar) or progress_bar 

390 parallel = n_bundles > 1 and cores > 1 

391 if in_memory is None: 

392 in_memory = not parallel 

393 with TemporaryDirectory() as scratch_cache: 

394 if not in_memory: 

395 if verbose: 

396 print(f"writing to scratch cache ({perf_counter() - start_time:.4f})") 

397 

398 def write_to_scratch(i: int, bundle: pandas.DataFrame): 

399 temp = f"{scratch_cache}/{i}.json" 

400 with open(temp, "wb") as scratch: 

401 pickle.dump(bundle, scratch) 

402 return temp 

403 

404 bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)] 

405 if parallel: 

406 if verbose: 

407 print(f"requesting in parallel ({perf_counter() - start_time:.4f})") 

408 waiter: "Queue[pandas.DataFrame]" = Queue() 

409 queue: "Queue[tuple[int, pandas.DataFrame]]" = Queue() 

410 manager = Process( 

411 target=_queue_manager, 

412 args=(queue, waiter, n_texts, n_bundles, use_pb, verbose), 

413 ) 

414 manager.start() 

415 nb = math.ceil(n_bundles / min(n_bundles, cores)) 

416 cores = math.ceil(n_bundles / nb) 

417 procs = [ 

418 Process( 

419 target=_process, 

420 args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue), 

421 ) 

422 for i in range(cores) 

423 ] 

424 for cl in procs: 

425 cl.start() 

426 res = waiter.get() 

427 else: 

428 if verbose: 

429 print(f"requesting serially ({perf_counter() - start_time:.4f})") 

430 pb = tqdm(total=n_texts, leave=verbose) if use_pb else None 

431 res = _process(bundles, opts, pb=pb) 

432 if pb is not None: 

433 pb.close() 

434 if verbose: 

435 print(f"done requesting ({perf_counter() - start_time:.4f})") 

436 

    # finalize
    if not res.shape[0]:
        msg = "no results"
        raise RuntimeError(msg)
    if isinstance(cache, str):
        _update_cache(res, cache, cache_format, verbose, start_time, [e[0] for e in opts["add"]])
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != n_original:
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
    if isinstance(framework_prefix, bool) and not framework_prefix:
        prefix_pattern = re.compile("^[^.]+\\.")
        res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res


def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[pandas.DataFrame]",
    n_texts: int,
    n_bundles: int,
    use_pb=True,
    verbose=False,
):
    if use_pb:
        pb = tqdm(total=n_texts, leave=verbose)
    res: List[pandas.DataFrame] = []
    for size, chunk in iter(queue.get, None):
        if isinstance(chunk, pandas.DataFrame):
            if use_pb:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
        else:
            break
    waiter.put(pandas.concat(res, ignore_index=True, sort=False))


def _process(
    bundles: list,
    opts: dict,
    queue: Union["Queue[tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> pandas.DataFrame:
    reses: List[pandas.DataFrame] = []
    for bundle in bundles:
        if isinstance(bundle, str):
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            bundle["text"] = _readin(
                bundle["text"],
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            body.append({"content": text, "request_id": text_hash, **opts["add"]})
        cached = None
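        # check the primary cache (if any) for previously scored texts in this bundle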

        if opts["cache"] and os.path.isdir(opts["cache"] + "/bin=h"):
            db = dataset.dataset(
                opts["cache"],
                partitioning=dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            if "text_hash" in db.schema.names:
                su = db.filter(compute.field("text_hash").isin(bundle["text_hash"]))
                if su.count_rows() > 0:
                    cached = su.to_table().to_pandas(split_blocks=True, self_destruct=True)
        res = "failed to retrieve results"
        if cached is None or len(cached) < len(bundle):
            if cached is None or not len(cached):
                res = _prepare_results(body, opts)
            else:
                fresh = ~compute.is_in(
                    bundle["text_hash"].to_list(), pyarrow.array(cached["text_hash"])
                ).to_pandas(split_blocks=True, self_destruct=True)
                res = _prepare_results([body[i] for i, ck in enumerate(fresh) if ck], opts)
            if not isinstance(res, str):
                if cached is not None:
                    if len(res) != len(cached) or not all(cached.columns.isin(res.columns)):
                        cached = _prepare_results(
                            [body[i] for i, ck in enumerate(fresh) if not ck], opts
                        )
                    res = pandas.concat([res, cached])
        else:
            res = cached
        if not isinstance(res, str):
            res = res.merge(bundle.loc[:, ["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (len(res), res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    return reses[0] if len(reses) == 1 else pandas.concat(reses, ignore_index=True, sort=False)


def _prepare_results(body: list, opts: dict):
    json_body = json.dumps(body, separators=(",", ":"))
    bundle_hash = (
        REQUEST_CACHE + hashlib.md5(json_body.encode()).hexdigest() + ".json"
        if opts["request_cache"]
        else ""
    )
    raw_res = _request(
        json_body,
        opts["url"],
        opts["auth"],
        opts["retries"],
        bundle_hash,
        opts["make_request"],
    )
    if isinstance(raw_res, str):
        return raw_res
    res = pandas.json_normalize(raw_res)
    res.rename(columns={"request_id": "text_hash"}, inplace=True)
    if "text_hash" not in res:
        res.insert(0, "text_hash", [text["request_id"] for text in body])
    res.drop(
        list({"response_id", "language", "version", "error"}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
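    # "bin" ("h" plus the first hash character) is the hive-partition key used by the cache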

    res.insert(res.shape[1], "bin", ["h" + h[0] for h in res["text_hash"]])
    return res


def _request(
    body: str,
    url: str,
    auth: requests.auth.HTTPBasicAuth,
    retries: int,
    cache="",
    execute=True,
) -> Union[dict, str]:
    if not os.path.isfile(cache):
        if not execute:
            return "`make_request` is `False`, but there are texts with no cached results"
        res = requests.post(url, body, auth=auth, timeout=9999)
        if cache and res.status_code == 200:
            with open(cache, "w", encoding="utf-8") as response:
                json.dump(res.json(), response)
    else:
        with open(cache, encoding="utf-8") as response:
            data = json.load(response)
        return data["results"] if "results" in data else data
    if res.status_code == 200:
        data = res.json()
        data = dict(data[0] if isinstance(data, list) else data)
        return data["results"] if "results" in data else data
    if os.path.isfile(cache):
        os.remove(cache)
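    # wait for any retry period indicated in the response (read as milliseconds), then try again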

    if retries > 0:
        cd = re.search(
            "[0-9]+(?:\\.[0-9]+)?",
            (
                res.json()["message"]
                if res.headers["Content-Type"] == "application/json"
                else res.text
            ),
        )
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        return _request(body, url, auth, retries - 1, cache)
    return f"request failed, and have no retries: {res.status_code}: {res.reason}"


def _update_cache(
    res: pandas.DataFrame,
    cache: str,
    cache_format: str,
    verbose: bool,
    start_time: float,
    add_names: list,
):
    part: pyarrow.Partitioning = dataset.partitioning(
        pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
    )
    exclude = {"id", *add_names}
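
    # write a single placeholder row to establish the cache's schema and partitioning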

    def initialize_cache():
        initial = res.iloc[[0]].drop(
            exclude.intersection(res.columns),
            axis="columns",
        )
        initial["text_hash"] = ""
        initial["bin"] = "h"
        initial.loc[
            :,
            ~initial.columns.isin(["summary.word_count", "summary.sentence_count"])
            & (initial.dtypes != object).to_list(),
        ] = 0.1
        dataset.write_dataset(
            pyarrow.Table.from_pandas(initial),
            cache,
            partitioning=part,
            format=cache_format,
            existing_data_behavior="overwrite_or_ignore",
        )

    if not os.path.isdir(cache + "/bin=h"):
        if verbose:
            print(f"initializing cache ({perf_counter() - start_time:.4f})")
        initialize_cache()
    db = dataset.dataset(cache, partitioning=part, format=cache_format)
    if any(name not in exclude and name not in db.schema.names for name in res.columns.to_list()):
        if verbose:
            print(
                "clearing cache since it contains columns not in new results",
                f"({perf_counter() - start_time:.4f})",
            )
        shutil.rmtree(cache, True)
        initialize_cache()
        db = dataset.dataset(cache, partitioning=part, format=cache_format)
    fresh = res[~res.duplicated(subset=["text_hash"])]
    su = db.filter(compute.field("text_hash").isin(fresh["text_hash"]))
    if su.count_rows() > 0:
        cached = ~compute.is_in(
            fresh["text_hash"].to_list(),
            su.scanner(columns=["text_hash"]).to_table()["text_hash"],
        ).to_pandas(split_blocks=True, self_destruct=True)
        if any(cached):
            fresh = fresh[cached.to_list()]
        else:
            return
    n_new = len(fresh)
    if n_new:
        if verbose:
            print(
                f"adding {n_new} result{'' if n_new == 1 else 's'}",
                f"to cache ({perf_counter() - start_time:.4f})",
            )
        dataset.write_dataset(
            pyarrow.Table.from_pandas(
                fresh.drop(
                    list(exclude.intersection(fresh.columns)),
                    axis="columns",
                )
            ),
            cache,
            partitioning=part,
            format=cache_format,
            existing_data_behavior="overwrite_or_ignore",
        )


def _manage_request_cache():
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        refreshed = time()
        log_file = REQUEST_CACHE + "log.txt"
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                logged = log.readline()
                if isinstance(logged, list):
                    logged = logged[0]
                refreshed = float(logged)
        else:
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(time()))
        if time() - refreshed > 86400:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                os.remove(cached_request)
    except Exception as exc:
        msg = "failed to manage request cache"
        raise RuntimeWarning(msg) from exc


def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
    text = []
    ids = []
    sel = []
    if text_column is not None:
        sel.append(text_column)
    if id_column is not None:
        sel.append(id_column)
    enc = encoding
    predict_encoding = enc is None
    if predict_encoding:
        detect = UniversalDetector()

    def handle_encoding(file: str):
        detect.reset()
        with open(file, "rb") as text:
            detect.feed(text.read())
        return detect.close()["encoding"]

    if os.path.splitext(paths[0])[1] == ".txt" and not sel:
        if collapse_lines:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    text.append(" ".join([line.rstrip() for line in texts]))
        else:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    lines = [line.rstrip() for line in texts]
                    text += lines
                    ids += (
                        [file]
                        if len(lines) == 1
                        else [file + str(i + 1) for i in range(len(lines))]
                    )
            return pandas.DataFrame({"text": text, "ids": ids})
    elif collapse_lines:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            text.append(" ".join(temp[text_column]))
    else:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            text += temp[text_column].to_list()
            ids += (
                temp[id_column].to_list()
                if id_column is not None
                else [file] if len(temp) == 1 else [file + str(i + 1) for i in range(len(temp))]
            )
        return pandas.DataFrame({"text": text, "ids": ids})
    return text