Coverage for src/receptiviti/manage_request.py: 80%

397 statements  

coverage.py v7.2.7, created at 2025-01-30 13:56 -0500

1"""Make requests to the API.""" 

2 

3import hashlib 

4import json 

5import math 

6import os 

7import pickle 

8import re 

9import sys 

10import urllib.parse 

11import warnings 

12from glob import glob 

13from multiprocessing import Process, Queue, current_process 

14from tempfile import TemporaryDirectory, gettempdir 

15from time import perf_counter, sleep, time 

16from typing import List, Union, Tuple 

17 

18import numpy 

19import pandas 

20import pyarrow 

21import pyarrow.compute 

22import pyarrow.dataset 

23import pyarrow.feather 

24import pyarrow.parquet 

25import requests 

26from chardet.universaldetector import UniversalDetector 

27from tqdm import tqdm 

28 

29from receptiviti.status import _resolve_request_def, status 

30 

31CACHE = gettempdir() + "/receptiviti_cache/" 

32REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/" 

33 

34 

35def _manage_request( 

36 text: Union[str, List[str], pandas.DataFrame, None] = None, 

37 ids: Union[str, List[str], List[int], None] = None, 

38 text_column: Union[str, None] = None, 

39 id_column: Union[str, None] = None, 

40 files: Union[List[str], None] = None, 

41 directory: Union[str, None] = None, 

42 file_type="txt", 

43 encoding: Union[str, None] = None, 

44 context="written", 

45 api_args: Union[dict, None] = None, 

46 bundle_size=1000, 

47 bundle_byte_limit=75e5, 

48 collapse_lines=False, 

49 retry_limit=50, 

50 request_cache=True, 

51 cores=1, 

52 collect_results=True, 

53 in_memory: Union[bool, None] = None, 

54 verbose=False, 

55 progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"), 

56 make_request=True, 

57 text_as_paths=False, 

58 dotenv: Union[bool, str] = True, 

59 cache=os.getenv("RECEPTIVITI_CACHE", ""), 

60 cache_overwrite=False, 

61 cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet"), 

62 key=os.getenv("RECEPTIVITI_KEY", ""), 

63 secret=os.getenv("RECEPTIVITI_SECRET", ""), 

64 url=os.getenv("RECEPTIVITI_URL", ""), 

65 version=os.getenv("RECEPTIVITI_VERSION", ""), 

66 endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""), 

67 to_norming=False, 

68) -> Tuple[pandas.DataFrame, Union[pandas.DataFrame, None], bool]: 

69 if cores > 1 and current_process().name != "MainProcess": 

70 return (pandas.DataFrame(), None, False) 

71 start_time = perf_counter() 

72 

73 if request_cache: 

74 if verbose: 

75 print(f"preparing request cache ({perf_counter() - start_time:.4f})") 

76 _manage_request_cache() 

77 

78 # resolve credentials and check status 

79 full_url, url, key, secret = _resolve_request_def(url, key, secret, dotenv) 
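    # Note (editorial): the regex below pulls an optional API version and endpoint out of the
    # resolved URL, so, for example, a `full_url` ending in "/v1/framework" yields version "v1"
    # and endpoint "framework".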

    url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", full_url)
    if url_parts:
        from_url = url_parts.groups()
        if not version and from_url[0] is not None:
            version = from_url[0]
        if not endpoint and from_url[1] is not None:
            endpoint = from_url[1]
    if to_norming:
        version = "v2"
        endpoint = "norming"
        request_cache = False
    else:
        if not version:
            version = os.getenv("RECEPTIVITI_VERSION", "v1")
        version = version.lower()
        if not version or not re.search("^v\\d+$", version):
            msg = f"invalid version: {version}"
            raise RuntimeError(msg)
        if not endpoint:
            endpoint_default = "framework" if version == "v1" else "analyze"
            endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default)
        endpoint = re.sub("^.*/", "", endpoint).lower()
        if not endpoint or re.search("[^a-z]", endpoint):
            msg = f"invalid endpoint: {endpoint}"
            raise RuntimeError(msg)
    api_status = status(url, key, secret, dotenv, verbose=False)
    if api_status is None or api_status.status_code != 200:
        msg = (
            "URL is not reachable"
            if api_status is None
            else f"API status failed: {api_status.status_code}: {api_status.reason}"
        )
        raise RuntimeError(msg)

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                id_column = "ids"
                text_column = "text"
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif os.path.isdir(os.path.dirname(text)):
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is not None:
            if text_column in text:
                text = text[text_column].to_list()
            else:
                msg = f"`text_column` ({text_column}) is not in `text`"
                raise IndexError(msg)
        else:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
    if isinstance(text, str):
        text = [text]
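    # Note (editorial): the `len(t) < 260` guard below appears to keep very long strings from being
    # tested as file paths; 260 characters matches the classic Windows MAX_PATH limit.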

    text_is_path = all(isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text)
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            if id_column is None:
                ids = text["ids"].to_list()
            elif id_column in text:
                ids = text[id_column].to_list()
            text = text["text"].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    data_subset = data[~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())]
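    # Note (editorial): duplicated, empty, and missing texts are dropped before bundling; the full
    # `data` frame (every text and id) is still returned, so results can be re-aligned with all inputs.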

    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    bundle_size = max(1, bundle_size)
    n_bundles = math.ceil(n_texts / min(1000, bundle_size))
    groups = data_subset.groupby(
        numpy.sort(numpy.tile(numpy.arange(n_bundles) + 1, bundle_size))[:n_texts],
        group_keys=False,
    )
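    # Worked example (editorial): with 2,500 valid texts and the default bundle_size of 1,000,
    # n_bundles = ceil(2500 / 1000) = 3; tiling 1..3 a thousand times and sorting gives the labels
    # [1] * 1000 + [2] * 1000 + [3] * 1000, truncated to 2,500, so the last bundle gets 500 texts.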

    bundles = []
    for _, group in groups:
        if sys.getsizeof(group) > bundle_byte_limit:
            start = current = end = 0
            for txt in group["text"]:
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = f"one of your texts is over the bundle size limit ({bundle_byte_limit / 1e6} MB)"
                    raise RuntimeError(msg)
                if (current + size) > bundle_byte_limit:
                    bundles.append(group.iloc[start:end])
                    start = end
                    current = size
                else:
                    current += size
                end += 1
            bundles.append(group.iloc[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    opts = {
        "url": (
            full_url
            if to_norming
            else (
                f"{url}/{version}/{endpoint}/bulk" if version == "v1" else f"{url}/{version}/{endpoint}/{context}"
            ).lower()
        ),
        "version": version,
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        "cache": cache,
        "cache_overwrite": cache_overwrite,
        "cache_format": cache_format,
        "to_norming": to_norming,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
        "collect_results": collect_results,
    }
    if version != "v1" and api_args:
        opts["url"] += "?" + urllib.parse.urlencode(api_args)
    opts["add_hash"] = hashlib.md5(
        json.dumps(
            {**opts["add"], "url": opts["url"], "key": key, "secret": secret},
            separators=(",", ":"),
        ).encode()
    ).hexdigest()
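    # Note (editorial): `add_hash` folds the extra API arguments, URL, and credentials into every
    # per-text hash computed below, so cached results are only reused when the same request settings apply.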

    if isinstance(progress_bar, str):
        progress_bar = progress_bar == "True"
    use_pb = (verbose and progress_bar) or progress_bar
    parallel = n_bundles > 1 and cores > 1
    if in_memory is None:
        in_memory = not parallel
    with TemporaryDirectory() as scratch_cache:
        if not in_memory:
            if verbose:
                print(f"writing to scratch cache ({perf_counter() - start_time:.4f})")

            def write_to_scratch(i: int, bundle: pandas.DataFrame):
                temp = f"{scratch_cache}/{i}.json"
                with open(temp, "wb") as scratch:
                    pickle.dump(bundle, scratch, -1)
                return temp

            bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)]
        if parallel:
            if verbose:
                print(f"requesting in parallel ({perf_counter() - start_time:.4f})")
            waiter: "Queue[List[Union[pandas.DataFrame, None]]]" = Queue()
            queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]" = Queue()
            manager = Process(
                target=_queue_manager,
                args=(queue, waiter, n_texts, n_bundles, use_pb, verbose),
            )
            manager.start()
            nb = math.ceil(n_bundles / min(n_bundles, cores))
            cores = math.ceil(n_bundles / nb)
            procs = [
                Process(
                    target=_process,
                    args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue),
                )
                for i in range(cores)
            ]
            for cl in procs:
                cl.start()
            res = waiter.get()
        else:
            if verbose:
                print(f"requesting serially ({perf_counter() - start_time:.4f})")
            pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
            res = _process(bundles, opts, pb=pb)
            if pb is not None:
                pb.close()
        if verbose:
            print(f"done requesting ({perf_counter() - start_time:.4f})")

    return (data, pandas.concat(res, ignore_index=True, sort=False) if opts["collect_results"] else None, id_specified)


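# Example (editorial, not part of this module): a minimal call using only the parameters shown in
# the signature above, assuming RECEPTIVITI_KEY, RECEPTIVITI_SECRET, and RECEPTIVITI_URL are set
# in the environment or a .env file:
#
#     data, results, id_specified = _manage_request(
#         text=["A first sample text.", "A second sample text."],
#         bundle_size=500,
#         cores=1,
#     )
#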

def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[List[Union[pandas.DataFrame, None]]]",
    n_texts: int,
    n_bundles: int,
    use_pb=True,
    verbose=False,
):
    if use_pb:
        pb = tqdm(total=n_texts, leave=verbose)
    res: List[Union[pandas.DataFrame, None]] = []
    for size, chunk in iter(queue.get, None):
        if size:
            if use_pb:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
        else:
            break
    waiter.put(res)


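# Note (editorial): each worker reports a bundle as (rows_processed, results) on `queue`; a failed
# bundle arrives as (0, None), which ends collection early, and the manager hands the collected
# list back through `waiter` once all bundles are in.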

def _process(
    bundles: List[pandas.DataFrame],
    opts: dict,
    queue: Union["Queue[tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> List[Union[pandas.DataFrame, None]]:
    reses: List[Union[pandas.DataFrame, None]] = []
    for bundle in bundles:
        if isinstance(bundle, str):
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            bundle["text"] = _readin(
                bundle["text"].to_list(),
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            if opts["version"] == "v1":
                body.append({"content": text, "request_id": text_hash, **opts["add"]})
            else:
                body.append({"text": text, "request_id": text_hash})
        ncached = 0
        cached: Union[pandas.DataFrame, None] = None
        cached_cols: List[str] = []
        if not opts["cache_overwrite"] and opts["cache"] and os.listdir(opts["cache"]):
            db = pyarrow.dataset.dataset(
                opts["cache"],
                partitioning=pyarrow.dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            cached_cols = db.schema.names
            if "text_hash" in cached_cols:
                su = db.filter(pyarrow.compute.field("text_hash").isin(bundle["text_hash"]))
                ncached = su.count_rows()
                if ncached > 0:
                    cached = (
                        su.to_table().to_pandas(split_blocks=True, self_destruct=True)
                        if opts["collect_results"]
                        else su.scanner(["text_hash"]).to_table().to_pandas(split_blocks=True, self_destruct=True)
                    )
        res: Union[str, pandas.DataFrame] = "failed to retrieve results"
        json_body = json.dumps(body, separators=(",", ":"))
        bundle_hash = hashlib.md5(json_body.encode()).hexdigest()
        if cached is None or ncached < len(bundle):
            if cached is None:
                res = _prepare_results(json_body, bundle_hash, opts)
            else:
                fresh = ~bundle["text_hash"].isin(cached["text_hash"])
                json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                res = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
            if not isinstance(res, str):
                if ncached:
                    if res.ndim != len(cached_cols) or not pandas.Series(cached_cols).isin(res.columns).all():
                        json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                        cached = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
                    if cached is not None and opts["collect_results"]:
                        res = pandas.concat([res, cached])
                if opts["cache"]:
                    writer = _get_writer(opts["cache_format"])
                    schema = pyarrow.schema(
                        (
                            col,
                            (
                                pyarrow.string()
                                if res[col].dtype == "O"
                                else (
                                    pyarrow.int32()
                                    if col in ["summary.word_count", "summary.sentence_count"]
                                    else pyarrow.float32()
                                )
                            ),
                        )
                        for col in res.columns
                        if col not in ["id", "bin", *(opts["add"].keys() if opts["add"] else [])]
                    )
                    for id_bin, d in res.groupby("bin"):
                        bin_dir = f"{opts['cache']}/bin={id_bin}"
                        os.makedirs(bin_dir, exist_ok=True)
                        writer(
                            pyarrow.Table.from_pandas(d, schema, preserve_index=False),
                            f"{bin_dir}/fragment-{bundle_hash}-0.{opts['cache_format']}",
                        )
        else:
            res = cached
        nres = len(res)
        if not opts["collect_results"]:
            reses.append(None)
        elif not isinstance(res, str):
            if "text_hash" in res:
                res = res.merge(bundle[["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (nres + ncached, res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    return reses



def _prepare_results(body: str, bundle_hash: str, opts: dict):
    raw_res = _request(
        body,
        opts["url"],
        opts["auth"],
        opts["retries"],
        REQUEST_CACHE + bundle_hash + ".json" if opts["request_cache"] else "",
        opts["to_norming"],
        opts["make_request"],
    )
    if isinstance(raw_res, str):
        return raw_res
    res = pandas.json_normalize(raw_res)
    if "request_id" in res:
        res.rename(columns={"request_id": "text_hash"}, inplace=True)
        res.drop(
            list({"response_id", "language", "version", "error"}.intersection(res.columns)),
            axis="columns",
            inplace=True,
        )
        res.insert(res.ndim, "bin", ["h" + h[0] for h in res["text_hash"]])
    return res



def _request(
    body: str,
    url: str,
    auth: requests.auth.HTTPBasicAuth,
    retries: int,
    cache="",
    to_norming=False,
    execute=True,
) -> Union[dict, str]:
    if not os.path.isfile(cache):
        if not execute:
            return "`make_request` is `False`, but there are texts with no cached results"
        if to_norming:
            res = requests.patch(url, body, auth=auth, timeout=9999)
        else:
            res = requests.post(url, body, auth=auth, timeout=9999)
        if cache and res.status_code == 200:
            with open(cache, "w", encoding="utf-8") as response:
                json.dump(res.json(), response)
    else:
        with open(cache, encoding="utf-8") as response:
            data = json.load(response)
        return data["results"] if "results" in data else data
    data = res.json()
    if res.status_code == 200:
        data = dict(data[0] if isinstance(data, list) else data)
        return data["results"] if "results" in data else data
    if os.path.isfile(cache):
        os.remove(cache)
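    # Note (editorial): error code 1420 in the response body is treated as rate limiting; the cooldown
    # is read out of the response message and, judging by the division by 1e3, is taken to be in
    # milliseconds, so a matched "250" sleeps 0.25 s before retrying with one fewer retry remaining.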

    if retries > 0 and "code" in data and data["code"] == 1420:
        cd = re.search(
            "[0-9]+(?:\\.[0-9]+)?",
            (res.json()["message"] if res.headers["Content-Type"] == "application/json" else res.text),
        )
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        return _request(body, url, auth, retries - 1, cache, to_norming)
    return f"request failed, and have no retries: {res.status_code}: {data['error'] if 'error' in data else res.reason}"



def _manage_request_cache():
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        refreshed = time()
        log_file = REQUEST_CACHE + "log.txt"
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                logged = log.readline()
                if isinstance(logged, list):
                    logged = logged[0]
                refreshed = float(logged)
        else:
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(time()))
        if time() - refreshed > 86400:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                os.remove(cached_request)
    except Exception as exc:
        warnings.warn(UserWarning(f"failed to manage request cache: {exc}"), stacklevel=2)



def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
    text = []
    ids = []
    sel = []
    if text_column is not None:
        sel.append(text_column)
    if id_column is not None:
        sel.append(id_column)
    enc = encoding
    predict_encoding = enc is None
    if predict_encoding:
        detect = UniversalDetector()

        def handle_encoding(file: str):
            detect.reset()
            with open(file, "rb") as text:
                while True:
                    chunk = text.read(1024)
                    if not chunk:
                        break
                    detect.feed(chunk)
                    if detect.done:
                        break
            detected = detect.close()["encoding"]
            if detected is None:
                msg = "failed to detect encoding; please specify with the `encoding` argument"
                raise RuntimeError(msg)
            return detected

    if os.path.splitext(paths[0])[1] == ".txt" and not sel:
        if collapse_lines:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    text.append(" ".join([line.rstrip() for line in texts]))
        else:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    lines = [line.rstrip() for line in texts]
                    text += lines
                    ids += [file] if len(lines) == 1 else [file + str(i + 1) for i in range(len(lines))]
            return pandas.DataFrame({"text": text, "ids": ids})
    elif collapse_lines:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            text.append(" ".join(temp[text_column]))
    else:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            text += temp[text_column].to_list()
            ids += (
                temp[id_column].to_list()
                if id_column is not None
                else [file] if len(temp) == 1 else [file + str(i + 1) for i in range(len(temp))]
            )
        return pandas.DataFrame({"text": text, "ids": ids})
    return text

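# Example (editorial, not part of this module; the file name is hypothetical): reading a CSV such as
# "comments.csv" that has "text" and "ids" columns,
#
#     _readin(["comments.csv"], "text", "ids", collapse_lines=False, encoding=None)
#
# returns a DataFrame with "text" and "ids" columns, detecting the file encoding with chardet.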


def _get_writer(write_format: str):
    if write_format == "parquet":
        return pyarrow.parquet.write_table
    if write_format == "feather":
        return pyarrow.feather.write_feather
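# Sketch (editorial, not part of this module): results written through `_get_writer` land under
# `cache/bin=h<first hash character>/fragment-<bundle hash>-0.<format>`, and could be read back with
# the same hive partitioning `_process` uses, assuming a parquet-format cache directory:
#
#     cached = pyarrow.dataset.dataset(
#         cache_dir,  # hypothetical path to the cache directory
#         partitioning=pyarrow.dataset.partitioning(
#             pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
#         ),
#         format="parquet",
#     ).to_table().to_pandas()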