Coverage for src/receptiviti/request.py: 81%

451 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-20 11:35 -0500

1"""Make requests to the API.""" 

2 

3import hashlib 

4import json 

5import math 

6import os 

7import pickle 

8import re 

9import shutil 

10import sys 

11from glob import glob 

12from multiprocessing import Process, Queue, cpu_count 

13from tempfile import TemporaryDirectory, gettempdir 

14from time import perf_counter, sleep, time 

15from typing import List, Union 

16 

17import numpy 

18import pandas 

19import pyarrow 

20import requests 

21from chardet.universaldetector import UniversalDetector 

22from pyarrow import compute, dataset 

23from tqdm import tqdm 

24 

25from receptiviti.readin_env import readin_env 

26from receptiviti.status import status 

27 

# Default directory of the primary results cache, placed in the system's temporary directory.
CACHE = gettempdir() + "/receptiviti_cache/"
# Directory holding the short-lived per-bundle request cache: raw responses are written here as
# `<hash>.json` files, alongside a `log.txt` timestamp used to decide when to purge them.
REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"

30 

31 

def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text=False,
    api_args: Union[dict, None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    clear_cache=False,
    request_cache=True,
    cores=cpu_count() - 2,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite=False,
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> pandas.DataFrame:
    """
    Send texts to be scored by the API.

    Args:
      text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of
        strings containing the text itself, or the path to a file from which to read in text.
        If a DataFrame, `text_column` is used to extract such a vector. A string may also
        represent a directory in which to search for files. To best ensure paths are not
        treated as texts, either set `text_as_paths` to `True`, or use `directory` to enter
        a directory path, or `files` to enter a vector of file paths.
      output (str): Path to a file to write results to.
      ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
        containing IDs.
      text_column (str): Column name in `text` containing text.
      id_column (str): Column name in `text` containing IDs.
      files (list[str]): Vector of file paths, as alternate entry to `text`.
      directory (str): A directory path to search for files in, as alternate entry to `text`.
      file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
      encoding (str | None): Encoding of file(s) to be read in; one of the
        [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
        If this is `None` (default), encoding will be predicted for each file, but this can
        potentially fail, resulting in mis-encoded characters. For best (and fastest) results,
        specify encoding.
      return_text (bool): If `True`, will include a `text` column in the output with the
        original text.
      api_args (dict): Additional arguments to include in the request.
      frameworks (str | list): One or more names of frameworks to return.
      framework_prefix (bool): If `False`, will drop framework prefix from column names.
        If one framework is selected, will default to `False`.
      bundle_size (int): Maximum number of texts per bundle (capped at 1000).
      bundle_byte_limit (float): Maximum byte size of each bundle.
      collapse_lines (bool): If `True`, will treat files as containing single texts, and
        collapse multiple lines.
      retry_limit (int): Number of times to retry a failed request.
      clear_cache (bool): If `True`, will delete the `cache` before processing.
      request_cache (bool): If `False`, will not temporarily save raw requests for reuse
        within a day.
      cores (int): Number of CPU cores to use when processing multiple bundles.
      in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when
        processed. Defaults to `True` when processing in parallel.
      verbose (bool): If `True`, will print status messages and preserve the progress bar.
      progress_bar (str | bool): If `False`, will not display a progress bar.
      overwrite (bool): If `True`, will overwrite an existing `output` file.
      make_request (bool): If `False`, will not send requests to the API; texts without
        cached results then cause a failure.
      text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
        Otherwise, this will be detected.
      dotenv (bool | str): Path to a .env file to read environment variables from. By default,
        will look for a file in the current directory or `~/Documents`.
        Passed to `readin_env` as `path`.
      cache (bool | str): Path to a cache directory, or `True` to use the default directory.
      cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
        but will store results in the cache (unlike `cache = False`).
      cache_format (str): File format of the cache, of available Arrow formats.
      key (str): Your API key.
      secret (str): Your API secret.
      url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
      version (str): Version of the API; defaults to `v1`.
      endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
      Scores associated with each input text.

    Raises:
      RuntimeError: If `output` exists and `overwrite` is `False`, the API status check fails,
        no text is provided, file entries do not exist, `ids` are mismatched or duplicated,
        a single text exceeds `bundle_byte_limit`, or no results are returned.
      IndexError: If `text_column` or `id_column` is not found in a `text` DataFrame.

    Cache:
      If `cache` is specified, results for unique texts are saved in an Arrow database
      in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
      subsequent requests. This ensures that the exact same texts are not re-sent to the API.
      This does, however, add some processing time and disc space usage.

      If `cache` is `True`, a default directory (`receptiviti_cache`) will be
      looked for in the system's temporary directory (`tempfile.gettempdir()`).

      The primary cache is checked when each bundle is processed, and existing results are
      loaded at that time. When processing many bundles in parallel, and many results have
      been cached, this can cause the system to freeze and potentially crash.
      To avoid this, limit the number of cores, or disable parallel processing.

      The `cache_format` arguments (or the `RECEPTIVITI_CACHE_FORMAT` environment variable) can be
      used to adjust the format of the cache.

      You can use the cache independently with
      `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.

      You can also set the `clear_cache` argument to `True` to clear the cache before it is used
      again, which may be useful if the cache has gotten big, or you know new results will be
      returned.

      Even if a cached result exists, it will be reprocessed if it does not have all of the
      variables of new results, but this depends on there being at least 1 uncached result. If,
      for instance, you add a framework to your account and want to reprocess a previously
      processed set of texts, you would need to first clear the cache.

      Either way, duplicated texts within the same call will only be sent once.

      The `request_cache` argument controls a more temporary cache of each bundle request. This
      is cleared after a day. You might want to set this to `False` if a new framework becomes
      available on your account and you want to process a set of text you re-processed recently.

      Another temporary cache is made when `in_memory` is `False`, which is the default when
      processing in parallel (when there is more than 1 bundle and `cores` is over 1). This is a
      temporary directory that contains a file for each unique bundle, which is read in as needed
      by the parallel workers.

    Parallelization:
      `text`s are split into bundles based on the `bundle_size` argument. Each bundle represents
      a single request to the API, which is why they are limited to 1000 texts and a total size
      of 10 MB. When there is more than one bundle and `cores` is greater than 1, bundles are
      processed by multiple cores.

      If you have texts spread across multiple files, they can be most efficiently processed in
      parallel if each file contains a single text (potentially collapsed from multiple lines).
      If files contain multiple texts (i.e., `collapse_lines=False`), then texts need to be
      read in before bundling in order to ensure bundles are under the length limit.
    """
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if request_cache:
        if verbose:
            print(f"preparing request cache ({perf_counter() - start_time:.4f})")
        _manage_request_cache()

    # resolve credentials and check status
    if dotenv:
        readin_env("." if isinstance(dotenv, bool) else dotenv)
    if not url:
        url = os.getenv("RECEPTIVITI_URL", "https://api.receptiviti.com")
    # a version and endpoint embedded in the URL act as defaults for those arguments
    url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", url)
    if url_parts:
        from_url = url_parts.groups()
        if not version and from_url[0] is not None:
            version = from_url[0]
        if not endpoint and from_url[1] is not None:
            endpoint = from_url[1]
    # reduce the URL to its base: ensure a scheme, and strip any version path or trailing slashes
    url = ("https://" if re.match("http", url, re.I) is None else "") + re.sub(
        "/+[Vv]\\d+(?:/.*)?$|/+$", "", url
    )
    if not key:
        key = os.getenv("RECEPTIVITI_KEY", "")
    if not secret:
        secret = os.getenv("RECEPTIVITI_SECRET", "")
    if not version:
        version = os.getenv("RECEPTIVITI_VERSION", "v1")
    if not endpoint:
        endpoint_default = "framework" if version.lower() == "v1" else "taxonomies"
        endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default)
    api_status = status(url, key, secret, dotenv, verbose=False)
    # NOTE(review): `status` presumably returns a `requests.Response` or `None`. A `Response`
    # is falsy for error statuses, so the status code is inspected directly; relying on
    # truthiness would report an authentication failure as an unreachable URL.
    status_code = getattr(api_status, "status_code", None)
    if status_code != 200:
        msg = (
            f"API status failed: {status_code}: {api_status.reason}"
            if status_code is not None
            else "URL is not reachable"
        )
        raise RuntimeError(msg)

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    # short strings are checked against the file system, unless explicitly marked as paths
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                id_column = "ids"
                text_column = "text"
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif os.path.isdir(os.path.dirname(text)):
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is not None:
            if text_column in text:
                text = text[text_column].to_list()
            else:
                msg = f"`text_column` ({text_column}) is not in `text`"
                raise IndexError(msg)
        else:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
    if isinstance(text, str):
        text = [text]
    text_is_path = all(
        isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text
    )
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        # files may hold several texts each, so they must be read before bundling
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            # `_readin` always names its columns `text` and `ids`, whatever the source
            # columns were called, so these fixed names are used to extract them
            ids = text["ids"].to_list()
            text = text["text"].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    n_original = len(data)
    # only unique, non-empty texts are sent; results are mapped back to duplicates afterward
    data_subset = data[
        ~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())
    ]
    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    # the cap of 1000 texts per bundle applies to the grouping itself, not just the count
    bundle_size = max(1, min(1000, bundle_size))
    groups = data_subset.groupby(numpy.arange(n_texts) // bundle_size, group_keys=False)
    bundles = []
    for _, group in groups:
        if sys.getsizeof(group) > bundle_byte_limit:
            # split an oversized group into consecutive slices that each fit the byte limit;
            # `start` marks the beginning of the open slice, `end` counts the texts seen so far
            start = current = end = 0
            for txt in group["text"]:
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = (
                        "one of your texts is over the bundle size"
                        f" limit ({bundle_byte_limit / 1e6} MB)"
                    )
                    raise RuntimeError(msg)
                if (current + size) > bundle_byte_limit:
                    # close the open slice, and start the next one at the current text
                    bundles.append(group[start:end])
                    start = end
                    current = size
                else:
                    current += size
                end += 1
            bundles.append(group[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    if isinstance(cache, bool) and cache:
        # `True` selects the default cache directory
        cache = CACHE
    if isinstance(cache, str):
        if cache:
            if clear_cache and os.path.exists(cache):
                shutil.rmtree(cache, True)
            os.makedirs(cache, exist_ok=True)
            if not cache_format:
                cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        else:
            cache = False
    opts = {
        "url": f"{url}/{version}/{endpoint}/bulk".lower(),
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        # an empty cache path stops workers from reading the cache (results are still stored)
        "cache": "" if cache_overwrite or isinstance(cache, bool) and not cache else cache,
        "cache_format": cache_format,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
    }
    # text hashes are salted with the request settings, so changing them invalidates matches
    opts["add_hash"] = hashlib.md5(
        json.dumps(
            {**opts["add"], "url": opts["url"], "key": key, "secret": secret},
            separators=(",", ":"),
        ).encode()
    ).hexdigest()
    if isinstance(progress_bar, str):
        progress_bar = progress_bar == "True"
    use_pb = (verbose and progress_bar) or progress_bar
    parallel = n_bundles > 1 and cores > 1
    if in_memory is None:
        in_memory = not parallel
    with TemporaryDirectory() as scratch_cache:
        if not in_memory:
            if verbose:
                print(f"writing to scratch cache ({perf_counter() - start_time:.4f})")

            def write_to_scratch(i: int, bundle: pandas.DataFrame):
                # bundles are pickled (despite the extension), and replaced by their file path
                temp = f"{scratch_cache}/{i}.json"
                with open(temp, "wb") as scratch:
                    pickle.dump(bundle, scratch)
                return temp

            bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)]
        if parallel:
            if verbose:
                print(f"requesting in parallel ({perf_counter() - start_time:.4f})")
            waiter: "Queue[pandas.DataFrame]" = Queue()
            queue: "Queue[tuple[int, pandas.DataFrame]]" = Queue()
            manager = Process(
                target=_queue_manager,
                args=(queue, waiter, n_texts, n_bundles, use_pb, verbose),
            )
            manager.start()
            # spread bundles evenly, then use only as many workers as have bundles to process
            nb = math.ceil(n_bundles / min(n_bundles, cores))
            cores = math.ceil(n_bundles / nb)
            procs = [
                Process(
                    target=_process,
                    args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue),
                )
                for i in range(cores)
            ]
            for cl in procs:
                cl.start()
            for cl in procs:
                cl.join()
            res = waiter.get()
        else:
            if verbose:
                print(f"requesting serially ({perf_counter() - start_time:.4f})")
            pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
            res = _process(bundles, opts, pb=pb)
            if pb is not None:
                pb.close()
        if verbose:
            print(f"done requesting ({perf_counter() - start_time:.4f})")

    # finalize
    if not res.shape[0]:
        msg = "no results"
        raise RuntimeError(msg)
    if isinstance(cache, str):
        # the names of the additional API arguments are passed to be excluded from the cache
        _update_cache(res, cache, cache_format, verbose, start_time, list(opts["add"]))
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != n_original:
        # copy results to the IDs of duplicated texts, which were not sent themselves
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    # the outer join restores rows for texts without results (e.g., empty texts)
    res = res.join(data["text"], how="outer")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
    if isinstance(framework_prefix, bool) and not framework_prefix:
        prefix_pattern = re.compile("^[^.]+\\.")
        res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res

494 

495 

def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[pandas.DataFrame]",
    n_texts: int,
    n_bundles: int,
    use_pb=True,
    verbose=False,
):
    """
    Collect results from parallel workers, and hand the combined results back.

    Reads `(size, chunk)` entries from `queue` until `n_bundles` results have been received,
    or a worker signals a failure by sending something other than a DataFrame as `chunk`.
    Whatever was collected is then combined and put on `waiter`.

    Args:
      queue (Queue): Queue that workers put `(size, chunk)` tuples on.
      waiter (Queue): Queue that receives the single, combined DataFrame.
      n_texts (int): Total number of texts, used as the total of the progress bar.
      n_bundles (int): Number of results to wait for.
      use_pb (bool): If `True`, will display a progress bar.
      verbose (bool): If `True`, will leave the progress bar displayed when done.
    """
    pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
    res: List[pandas.DataFrame] = []
    try:
        for size, chunk in iter(queue.get, None):
            if not isinstance(chunk, pandas.DataFrame):
                # a worker failed; stop waiting for the remaining bundles
                break
            if pb is not None:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
    finally:
        if pb is not None:
            pb.close()
        # Always respond, even with nothing collected: `pandas.concat` raises on an empty
        # list, and the process waiting on `waiter` would otherwise block forever.
        waiter.put(
            pandas.concat(res, ignore_index=True, sort=False) if res else pandas.DataFrame()
        )

517 

518 

def _process(
    bundles: list,
    opts: dict,
    queue: Union["Queue[tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> pandas.DataFrame:
    """
    Retrieve results for each bundle, from the cache where possible, and otherwise the API.

    Args:
      bundles (list): Bundles to process; each is a DataFrame with `text` and `id` columns, or
        the path to a file containing a pickled version of such a DataFrame.
      opts (dict): Request options, as assembled by `request`.
      queue (Queue | None): If specified, a `(size, results)` tuple is put on it after each
        bundle, or `(0, None)` if the bundle failed.
      pb (tqdm | None): A progress bar to update after each bundle, when `queue` is not used.

    Returns:
      Results of all bundles, with `text_hash` and `id` columns.

    Raises:
      RuntimeError: If results could not be retrieved for a bundle.
    """
    reses: List[pandas.DataFrame] = []
    for bundle in bundles:
        if isinstance(bundle, str):
            # NOTE: only unpickle files written by this package to its own scratch directory
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            # Pass a plain list: bundles keep the index of the original data, so positional
            # access (`paths[0]`) on the column itself would be a label lookup that fails for
            # every bundle that does not contain label 0.
            bundle["text"] = _readin(
                bundle["text"].to_list(),
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            # hashes identify texts within a set of request options, and serve as request IDs
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            body.append({"content": text, "request_id": text_hash, **opts["add"]})
        cached = None
        # the `bin=h` partition is written when the cache is initialized
        if opts["cache"] and os.path.isdir(opts["cache"] + "/bin=h"):
            db = dataset.dataset(
                opts["cache"],
                partitioning=dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            if "text_hash" in db.schema.names:
                su = db.filter(compute.field("text_hash").isin(bundle["text_hash"]))
                if su.count_rows() > 0:
                    cached = su.to_table().to_pandas(split_blocks=True, self_destruct=True)
        res = "failed to retrieve results"
        if cached is None or len(cached) < len(bundle):
            if cached is None or not len(cached):
                res = _prepare_results(body, opts)
            else:
                # request only the texts that are not in the cache
                fresh = ~compute.is_in(
                    bundle["text_hash"].to_list(), pyarrow.array(cached["text_hash"])
                ).to_pandas(split_blocks=True, self_destruct=True)
                res = _prepare_results([body[i] for i, ck in enumerate(fresh) if ck], opts)
            if not isinstance(res, str) and cached is not None:
                # Cached results are only reused if they have the same variables as the new
                # results. This compares the columns; comparing the number of rows, as was
                # done before, re-requested cached texts whenever the counts happened to differ.
                if len(res.columns) != len(cached.columns) or not all(
                    cached.columns.isin(res.columns)
                ):
                    cached = _prepare_results(
                        [body[i] for i, ck in enumerate(fresh) if not ck], opts
                    )
                # a failed re-request is a message, which is passed on rather than combined
                res = cached if isinstance(cached, str) else pandas.concat([res, cached])
        else:
            res = cached
        if not isinstance(res, str):
            res = res.merge(bundle.loc[:, ["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (len(res), res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    if not reses:
        return pandas.DataFrame()
    return reses[0] if len(reses) == 1 else pandas.concat(reses, ignore_index=True, sort=False)

585 

586 

def _prepare_results(body: list, opts: dict):
    """
    Send a bundle to the API (or load its cached response), and tabulate the results.

    Args:
      body (list): Entries to be scored, each a dict with at least a `request_id`.
      opts (dict): Request options; `request_cache`, `url`, `auth`, `retries`, and
        `make_request` are used here.

    Returns:
      A DataFrame of results with `text_hash` and `bin` columns, or the message returned by
      `_request` if results could not be retrieved.
    """
    payload = json.dumps(body, separators=(",", ":"))
    # the response is saved under a hash of the exact payload, when the request cache is used
    if opts["request_cache"]:
        cache_file = REQUEST_CACHE + hashlib.md5(payload.encode()).hexdigest() + ".json"
    else:
        cache_file = ""
    response = _request(
        payload,
        opts["url"],
        opts["auth"],
        opts["retries"],
        cache_file,
        opts["make_request"],
    )
    if isinstance(response, str):
        return response

    scores = pandas.json_normalize(response)
    scores.rename(columns={"request_id": "text_hash"}, inplace=True)
    if "text_hash" not in scores:
        # fall back on the IDs that were sent
        scores.insert(0, "text_hash", [entry["request_id"] for entry in body])
    unwanted = {"response_id", "language", "version", "error"}
    scores.drop(
        [column for column in scores.columns if column in unwanted],
        axis="columns",
        inplace=True,
    )
    # cache partitions are keyed by the first character of each hash
    bins = ["h" + text_hash[0] for text_hash in scores["text_hash"]]
    scores.insert(len(scores.columns), "bin", bins)
    return scores

615 

616 

def _request(
    body: str,
    url: str,
    auth: "requests.auth.HTTPBasicAuth",
    retries: int,
    cache="",
    execute=True,
) -> Union[dict, str]:
    """
    Post a bundle to the API, retrying on failure, or load its previously saved response.

    Args:
      body (str): JSON-encoded bundle to send.
      url (str): Full URL to post to.
      auth (requests.auth.HTTPBasicAuth): Credentials to send with the request.
      retries (int): Number of times to retry if the request fails.
      cache (str): Path to a file to load the response from if it exists, and to save a
        successful response to otherwise. Nothing is loaded or saved if this is empty.
      execute (bool): If `False`, will not make a request when there is no saved response.

    Returns:
      The `results` entry of the response if there is one, or else the response itself;
      or a message (`str`) if no results could be retrieved.
    """
    while True:
        if os.path.isfile(cache):
            with open(cache, encoding="utf-8") as response:
                data = json.load(response)
            return data["results"] if "results" in data else data
        if not execute:
            return "`make_request` is `False`, but there are texts with no cached results"
        res = requests.post(url, body, auth=auth, timeout=9999)
        if res.status_code == 200:
            data = res.json()
            if cache:
                with open(cache, "w", encoding="utf-8") as response:
                    json.dump(data, response)
            data = dict(data[0] if isinstance(data, list) else data)
            return data["results"] if "results" in data else data
        if retries <= 0:
            return f"request failed, and have no retries: {res.status_code}: {res.reason}"
        # Look for a wait time in the error message. The header and message are read
        # defensively, since either may be missing, and the content type may carry a charset.
        message = res.text
        if "application/json" in res.headers.get("Content-Type", ""):
            try:
                message = str(res.json().get("message", message))
            except (ValueError, AttributeError):
                message = res.text
        cd = re.search("[0-9]+(?:\\.[0-9]+)?", message)
        # a number in the message is taken as a wait time in milliseconds
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        retries -= 1

652 

653 

def _update_cache(
    res: pandas.DataFrame,
    cache: str,
    cache_format: str,
    verbose: bool,
    start_time: float,
    add_names: list,
):
    """
    Add results that are not yet in the cache to the Arrow dataset in `cache`.

    The cache is initialized if it does not exist, and rebuilt if `res` has columns that
    the existing dataset does not.

    Args:
      res (pandas.DataFrame): Results, including `text_hash` and `bin` columns.
      cache (str): Path to the cache directory.
      cache_format (str): File format of the dataset.
      verbose (bool): If `True`, will print status messages.
      start_time (float): Reference time for the elapsed times in status messages.
      add_names (list): Names of columns to leave out of the cache, in addition to `id`.
    """
    hive = dataset.partitioning(
        pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
    )
    skipped = {"id", *add_names}

    def store(frame: pandas.DataFrame):
        dataset.write_dataset(
            pyarrow.Table.from_pandas(frame),
            cache,
            partitioning=hive,
            format=cache_format,
            existing_data_behavior="overwrite_or_ignore",
        )

    def open_cache():
        return dataset.dataset(cache, partitioning=hive, format=cache_format)

    def seed_cache():
        # write a placeholder row (in the `h` bin) that establishes the columns of the cache
        seed = res.iloc[[0]].drop(list(skipped.intersection(res.columns)), axis="columns")
        seed["text_hash"] = ""
        seed["bin"] = "h"
        is_count = seed.columns.isin(["summary.word_count", "summary.sentence_count"])
        is_numeric = (seed.dtypes != object).to_list()
        seed.loc[:, ~is_count & is_numeric] = 0.1
        store(seed)

    if not os.path.isdir(cache + "/bin=h"):
        if verbose:
            print(f"initializing cache ({perf_counter() - start_time:.4f})")
        seed_cache()
    db = open_cache()

    known = db.schema.names
    for name in res.columns.to_list():
        if name in skipped or name in known:
            continue
        # a column of the results is missing from the cache, so start the cache over
        if verbose:
            print(
                "clearing cache since it contains columns not in new results",
                f"({perf_counter() - start_time:.4f})",
            )
        shutil.rmtree(cache, True)
        seed_cache()
        db = open_cache()
        break

    fresh = res[~res.duplicated(subset=["text_hash"])]
    existing = db.filter(compute.field("text_hash").isin(fresh["text_hash"]))
    if existing.count_rows() > 0:
        existing_hashes = existing.scanner(columns=["text_hash"]).to_table()["text_hash"]
        uncached = ~compute.is_in(fresh["text_hash"].to_list(), existing_hashes).to_pandas(
            split_blocks=True, self_destruct=True
        )
        if not any(uncached):
            return
        fresh = fresh[uncached.to_list()]

    n_new = len(fresh)
    if not n_new:
        return
    if verbose:
        print(
            f"adding {n_new} result{'' if n_new == 1 else 's'}",
            f"to cache ({perf_counter() - start_time:.4f})",
        )
    store(fresh.drop(list(skipped.intersection(fresh.columns)), axis="columns"))

731 

732 

def _manage_request_cache():
    """
    Ensure the request cache directory exists, and clear it when it is over a day old.

    The time the cache was started is kept in a `log.txt` file within the directory. When
    that time is more than a day ago (or cannot be read), all cached requests are removed,
    and the time is reset.

    Raises:
      RuntimeWarning: If the cache could not be managed.
    """
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        now = time()
        log_file = REQUEST_CACHE + "log.txt"
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                logged = log.readline().strip()
            try:
                expired = now - float(logged) > 86400
            except ValueError:
                # an empty or corrupted log gives no age, so the cache is treated as old
                expired = True
            write_log = expired
        else:
            expired = False
            write_log = True
        if expired:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                try:
                    os.remove(cached_request)
                except FileNotFoundError:
                    # already removed, such as by another process clearing the same cache
                    pass
        if write_log:
            # The time is reset whenever the cache is cleared; otherwise, the old time would
            # cause the cache to be cleared again on every following call.
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(now))
    except Exception as exc:
        msg = "failed to manage request cache"
        raise RuntimeWarning(msg) from exc

753 

754 

def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
    """
    Read texts in from files.

    Files are read as plain text if the first path has a `.txt` extension and no columns are
    specified, and as CSV otherwise.

    Args:
      paths (list[str]): Paths to the files to read in.
      text_column (str | None): Name of the column containing text, in CSV files.
      id_column (str | None): Name of the column containing IDs, in CSV files.
      collapse_lines (bool): If `True`, each file is read in as a single text, with its lines
        (or rows) joined by spaces.
      encoding (str | None): Encoding of the files; predicted for each file if `None`.

    Returns:
      A list with one text per file if `collapse_lines` is `True`; otherwise, a DataFrame with
      `text` and `ids` columns, containing a row for each line (or row) of each file.

    Raises:
      IndexError: If `text_column` is not found in a CSV file.
    """
    # `paths` may arrive as something other than a list (such as a pandas Series, where
    # `paths[0]` would look up a label rather than the first entry)
    paths = list(paths)
    text = []
    ids = []
    sel = [column for column in (text_column, id_column) if column is not None]
    detect = UniversalDetector() if encoding is None else None

    def resolve_encoding(file: str):
        if detect is None:
            return encoding
        detect.reset()
        with open(file, "rb") as raw:
            # Feed the detector until it is confident, up to 64 KiB. Looking at only the
            # first few bytes makes most files appear to be ASCII, which then causes any
            # later non-ASCII characters to be dropped when the file is read in.
            for _ in range(64):
                chunk = raw.read(1024)
                if not chunk:
                    break
                detect.feed(chunk)
                if detect.done:
                    break
        return detect.close()["encoding"]

    def number_ids(file: str, n: int):
        # a file with a single text is identified by its path; others get numbered suffixes
        return [file] if n == 1 else [f"{file}{i}" for i in range(1, n + 1)]

    first_extension = os.path.splitext(paths[0])[1] if paths else ""
    if first_extension == ".txt" and not sel:
        for file in paths:
            with open(file, encoding=resolve_encoding(file), errors="ignore") as texts:
                lines = [line.rstrip() for line in texts]
            if collapse_lines:
                text.append(" ".join(lines))
            else:
                text += lines
                ids += number_ids(file, len(lines))
    else:
        for file in paths:
            temp = pandas.read_csv(file, encoding=resolve_encoding(file), usecols=sel)
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            if collapse_lines:
                # empty cells are read in as missing values, which cannot be joined
                text.append(" ".join(temp[text_column].dropna().astype(str)))
            else:
                text += temp[text_column].to_list()
                ids += (
                    temp[id_column].to_list()
                    if id_column is not None
                    else number_ids(file, len(temp))
                )
    if collapse_lines:
        return text
    return pandas.DataFrame({"text": text, "ids": ids})