Coverage for src/receptiviti/manage_request.py: 81% (417 statements)
1"""Make requests to the API."""
3import hashlib
4import json
5import math
6import os
7import pickle
8import re
9import sys
10import urllib.parse
11import warnings
12from glob import glob
13from multiprocessing import Process, Queue, current_process
14from tempfile import TemporaryDirectory, gettempdir
15from time import perf_counter, sleep, time
16from typing import Dict, List, Literal, Tuple, TypedDict, Union
18import numpy
19import pandas
20import pyarrow
21import pyarrow.compute
22import pyarrow.dataset
23import pyarrow.feather
24import pyarrow.parquet
25import requests
26import requests.auth
27from chardet.universaldetector import UniversalDetector
28from tqdm import tqdm
30from receptiviti.status import _resolve_request_def, status
32CACHE = gettempdir() + "/receptiviti_cache/"
33REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"
36class Options(TypedDict):
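    """Resolved request options prepared by `_manage_request` and passed to bundle processors."""
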
    url: str
    version: str
    auth: requests.auth.HTTPBasicAuth
    retries: int
    add: Dict[str, Union[str, List[str]]]
    request_cache: bool
    cache: str
    cache_overwrite: bool
    cache_format: Literal["parquet", "feather"]
    to_norming: bool
    make_request: bool
    text_is_path: bool
    text_column: Union[str, None]
    id_column: Union[str, None]
    collapse_lines: bool
    encoding: Union[str, None]
    collect_results: bool
    add_hash: str


def _manage_request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    context: str = "written",
    api_args: Union[Dict[str, Union[str, List[str]]], None] = None,
    bundle_size: int = 1000,
    bundle_byte_limit: float = 75e5,
    collapse_lines: bool = False,
    retry_limit: int = 50,
    request_cache: bool = True,
    cores: int = 1,
    collect_results: bool = True,
    in_memory: Union[bool, None] = None,
    verbose: bool = False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    make_request: bool = True,
    text_as_paths: bool = False,
    dotenv: Union[bool, str] = True,
    cache: str = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite: bool = False,
    cache_format: str = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet"),
    key: str = os.getenv("RECEPTIVITI_KEY", ""),
    secret: str = os.getenv("RECEPTIVITI_SECRET", ""),
    url: str = os.getenv("RECEPTIVITI_URL", ""),
    version: str = os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint: str = os.getenv("RECEPTIVITI_ENDPOINT", ""),
    to_norming: bool = False,
) -> Tuple[pandas.DataFrame, Union[pandas.DataFrame, None], bool]:
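    """
    Resolve credentials and inputs, bundle texts, and submit them to the API.

    Returns a tuple of the prepared texts and ids, the collected results
    (None when `collect_results` is False), and whether ids were specified.
    """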
    if cores > 1 and current_process().name != "MainProcess":
        return (pandas.DataFrame(), None, False)
    start_time = perf_counter()

    if request_cache:
        if verbose:
            print(f"preparing request cache ({perf_counter() - start_time:.4f})")
        _manage_request_cache()

    # resolve credentials and check status
    full_url, url, key, secret = _resolve_request_def(url, key, secret, dotenv)
    url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", full_url)
    if url_parts:
        from_url = url_parts.groups()
        if not version and from_url[0] is not None:
            version = from_url[0]
        if not endpoint and from_url[1] is not None:
            endpoint = from_url[1]
    if to_norming:
        version = "v2"
        endpoint = "norming"
        request_cache = False
    else:
        if not version:
            version = os.getenv("RECEPTIVITI_VERSION", "v1")
        version = version.lower()
        if not version or not re.search("^v\\d+$", version):
            msg = f"invalid version: {version}"
            raise RuntimeError(msg)
        if not endpoint:
            endpoint_default = "framework" if version == "v1" else "analyze"
            endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default)
        endpoint = re.sub("^.*/", "", endpoint).lower()
        if not endpoint or re.search("[^a-z]", endpoint):
            msg = f"invalid endpoint: {endpoint}"
            raise RuntimeError(msg)
    api_status = status(url, key, secret, dotenv, verbose=False)
    if api_status is None or api_status.status_code != 200:
        msg = (
            "URL is not reachable"
            if api_status is None
            else f"API status failed: {api_status.status_code}: {api_status.reason}"
        )
        raise RuntimeError(msg)

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                id_column = "ids"
                text_column = "text"
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif os.path.isdir(os.path.dirname(text)):
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is not None:
            if text_column in text:
                text = text[text_column].to_list()
            else:
                msg = f"`text_column` ({text_column}) is not in `text`"
                raise IndexError(msg)
        else:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
    if isinstance(text, str):
        text = [text]
    text_is_path = all(isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text)
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            if id_column is None:
                ids = text["ids"].to_list()
            elif id_column in text:
                ids = text[id_column].to_list()
            text = text["text"].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    data_subset = data[~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())]
    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    bundle_size = max(1, bundle_size)
    n_bundles = math.ceil(n_texts / min(1000, bundle_size))
    groups = data_subset.groupby(
        numpy.sort(numpy.tile(numpy.arange(n_bundles) + 1, bundle_size))[:n_texts],
        group_keys=False,
    )
    bundles: Union[List[pandas.DataFrame], List[str]] = []
    for _, group in groups:
        if sys.getsizeof(group) > bundle_byte_limit:
            start = current = end = 0
            for txt in group["text"]:
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = f"one of your texts is over the bundle size limit ({bundle_byte_limit / 1e6} MB)"
                    raise RuntimeError(msg)
                if (current + size) > bundle_byte_limit:
                    bundles.append(group.iloc[start:end])
                    start = end
                    current = size
                else:
                    current += size
                end += 1
            bundles.append(group.iloc[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    opts: Options = {
        "url": (
            full_url
            if to_norming
            else (
                f"{url}/{version}/{endpoint}/bulk" if version == "v1" else f"{url}/{version}/{endpoint}/{context}"
            ).lower()
        ),
        "version": version,
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        "cache": cache,
        "cache_overwrite": cache_overwrite,
        "cache_format": "feather" if cache_format == "feather" else "parquet",
        "to_norming": to_norming,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
        "collect_results": collect_results,
        "add_hash": "",
    }
    if version != "v1" and api_args:
        opts["url"] += "?" + urllib.parse.urlencode(api_args)
    opts["add_hash"] = hashlib.md5(
        json.dumps(
            {**opts["add"], "url": opts["url"], "key": key, "secret": secret},
            separators=(",", ":"),
        ).encode()
    ).hexdigest()
    if isinstance(progress_bar, str):
        progress_bar = progress_bar == "True"
    use_pb = (verbose and progress_bar) or progress_bar
    parallel = n_bundles > 1 and cores > 1
    if in_memory is None:
        in_memory = not parallel
    with TemporaryDirectory() as scratch_cache:
        if not in_memory:
            if verbose:
                print(f"writing to scratch cache ({perf_counter() - start_time:.4f})")

            def write_to_scratch(i: int, bundle: pandas.DataFrame):
                temp = f"{scratch_cache}/{i}.json"
                with open(temp, "wb") as scratch:
                    pickle.dump(bundle, scratch, -1)
                return temp

            bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)]
        if parallel:
            if verbose:
                print(f"requesting in parallel ({perf_counter() - start_time:.4f})")
            waiter: "Queue[List[Union[pandas.DataFrame, None]]]" = Queue()
            queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]" = Queue()
            manager = Process(
                target=_queue_manager,
                args=(queue, waiter, n_texts, n_bundles, use_pb, verbose),
            )
            manager.start()
            nb = math.ceil(n_bundles / min(n_bundles, cores))
            cores = math.ceil(n_bundles / nb)
            procs = [
                Process(
                    target=_process,
                    args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue),
                )
                for i in range(cores)
            ]
            for cl in procs:
                cl.start()
            res = waiter.get()
        else:
            if verbose:
                print(f"requesting serially ({perf_counter() - start_time:.4f})")
            pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
            res = _process(bundles, opts, pb=pb)
            if pb is not None:
                pb.close()
        if verbose:
            print(f"done requesting ({perf_counter() - start_time:.4f})")

    return (data, pandas.concat(res, ignore_index=True, sort=False) if opts["collect_results"] else None, id_specified)


def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[List[Union[pandas.DataFrame, None]]]",
    n_texts: int,
    n_bundles: int,
    use_pb: bool = True,
    verbose: bool = False,
):
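    """Collect results from worker processes, tracking progress until all bundles have arrived."""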
    if use_pb:
        pb = tqdm(total=n_texts, leave=verbose)
    res: List[Union[pandas.DataFrame, None]] = []
    for size, chunk in iter(queue.get, None):
        if size:
            if use_pb:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
        else:
            break
    waiter.put(res)


def _process(
    bundles: List[pandas.DataFrame],
    opts: Options,
    queue: Union["Queue[Tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> List[Union[pandas.DataFrame, None]]:
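    """
    Retrieve results for each bundle, reading from and writing to the results
    cache when one is configured; progress is reported through `queue` or `pb`.
    """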
    reses: List[Union[pandas.DataFrame, None]] = []
    for bundle in bundles:
        if isinstance(bundle, str):
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            bundle["text"] = _readin(
                bundle["text"].to_list(),
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            if opts["version"] == "v1":
                body.append({"content": text, "request_id": text_hash, **opts["add"]})
            else:
                body.append({"text": text, "request_id": text_hash})
        ncached = 0
        cached: Union[pandas.DataFrame, None] = None
        cached_cols: List[str] = []
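        # The results cache is a hive-partitioned dataset: one `bin=h<first hash character>`
        # directory per partition, holding parquet or feather fragments named by bundle hash.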
        if not opts["cache_overwrite"] and opts["cache"] and os.listdir(opts["cache"]):
            db = pyarrow.dataset.dataset(
                opts["cache"],
                partitioning=pyarrow.dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            cached_cols = db.schema.names
            if "text_hash" in cached_cols:
                su = db.filter(pyarrow.compute.field("text_hash").isin(bundle["text_hash"]))
                ncached = su.count_rows()
                if ncached > 0:
                    cached = (
                        su.to_table().to_pandas(split_blocks=True, self_destruct=True)
                        if opts["collect_results"]
                        else su.scanner(["text_hash"]).to_table().to_pandas(split_blocks=True, self_destruct=True)
                    )
        res: Union[str, pandas.DataFrame] = "failed to retrieve results"
        json_body = json.dumps(body, separators=(",", ":"))
        bundle_hash = hashlib.md5(json_body.encode()).hexdigest()
        if cached is None or ncached < len(bundle):
            if cached is None:
                res = _prepare_results(json_body, bundle_hash, opts)
            else:
                fresh = ~bundle["text_hash"].isin(cached["text_hash"])
                json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                res = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
            if not isinstance(res, str):
                if ncached:
                    if res.ndim != len(cached_cols) or not pandas.Series(cached_cols).isin(res.columns).all():
                        json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                        cached = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
                    if cached is not None and opts["collect_results"]:
                        res = pandas.concat([res, cached])
                if opts["cache"]:
                    writer = _get_writer(opts["cache_format"])
                    schema = pyarrow.schema(
                        (
                            col,
                            (
                                pyarrow.string()
                                if res[col].dtype == "O"
                                else (
                                    pyarrow.int32()
                                    if col in ["summary.word_count", "summary.sentence_count"]
                                    else pyarrow.float32()
                                )
                            ),
                        )
                        for col in res.columns
                        if col not in ["id", "bin", *(opts["add"].keys() if opts["add"] else [])]
                    )
                    for id_bin, d in res.groupby("bin"):
                        bin_dir = f"{opts['cache']}/bin={id_bin}"
                        os.makedirs(bin_dir, exist_ok=True)
                        writer(
                            pyarrow.Table.from_pandas(d, schema, preserve_index=False),
                            f"{bin_dir}/fragment-{bundle_hash}-0.{opts['cache_format']}",
                        )
        else:
            res = cached
        nres = len(res)
        if not opts["collect_results"]:
            reses.append(None)
        elif not isinstance(res, str):
            if "text_hash" in res:
                res = res.merge(bundle[["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (nres + ncached, res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    return reses


def _prepare_results(body: str, bundle_hash: str, opts: Options):
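    """Send a request body to the API and normalize the response into a DataFrame keyed by text hash."""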
    raw_res = _request(
        body,
        opts["url"],
        opts["auth"],
        opts["retries"],
        REQUEST_CACHE + bundle_hash + ".json" if opts["request_cache"] else "",
        opts["to_norming"],
        opts["make_request"],
    )
    if isinstance(raw_res, str):
        return raw_res
    res = pandas.json_normalize(raw_res)
    if "request_id" in res:
        res.rename(columns={"request_id": "text_hash"}, inplace=True)
        res.drop(
            list({"response_id", "language", "version", "error"}.intersection(res.columns)),
            axis="columns",
            inplace=True,
        )
        res.insert(res.ndim, "bin", ["h" + h[0] for h in res["text_hash"]])
    return res


def _request(
    body: str,
    url: str,
    auth: requests.auth.HTTPBasicAuth,
    retries: int,
    cache: str = "",
    to_norming: bool = False,
    execute: bool = True,
) -> Union[Dict[str, str], str]:
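    """
    Send a request body to the API (POST, or PATCH when norming), returning
    parsed results, a previously cached response, or an error message string.
    """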
    if not os.path.isfile(cache):
        if not execute:
            return "`make_request` is `False`, but there are texts with no cached results"
        if to_norming:
            res = requests.patch(url, body, auth=auth, timeout=9999)
        else:
            res = requests.post(url, body, auth=auth, timeout=9999)
        if cache and res.status_code == 200:
            with open(cache, "w", encoding="utf-8") as response:
                json.dump(res.json(), response)
    else:
        with open(cache, encoding="utf-8") as response:
            data = json.load(response)
        return data["results"] if "results" in data else data
    data = res.json()
    if res.status_code == 200:
        data = dict(data[0] if isinstance(data, list) else data)
        return data["results"] if "results" in data else data
    if os.path.isfile(cache):
        os.remove(cache)
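    # A 1420 code indicates rate limiting; the message reports a cooldown in
    # milliseconds, which is slept off (1 second by default) before retrying.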
    if retries > 0 and "code" in data and data["code"] == 1420:
        cd = re.search(
            "[0-9]+(?:\\.[0-9]+)?",
            (res.json()["message"] if res.headers["Content-Type"] == "application/json" else res.text),
        )
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        return _request(body, url, auth, retries - 1, cache, to_norming)
    return f"request failed, and have no retries: {res.status_code}: {data['error'] if 'error' in data else res.reason}"


def _manage_request_cache():
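    """Clear cached responses if the cache has not been refreshed in over a day, logging the refresh time."""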
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        refreshed = time()
        log_file = REQUEST_CACHE + "log.txt"
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                logged = log.readline()
                if isinstance(logged, list):
                    logged = logged[0]
                refreshed = float(logged)
        else:
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(time()))
        if time() - refreshed > 86400:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                os.remove(cached_request)
    except Exception as exc:
        warnings.warn(UserWarning(f"failed to manage request cache: {exc}"), stacklevel=2)


def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
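    """
    Read texts from plain-text or CSV files, detecting encoding with chardet when none is given.

    Returns a list of texts when `collapse_lines` is True, and otherwise a
    DataFrame of texts with file-based ids.
    """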
    text = []
    ids = []
    sel = []
    if text_column is not None:
        sel.append(text_column)
    if id_column is not None:
        sel.append(id_column)
    enc = encoding
    predict_encoding = enc is None
    if predict_encoding:
        detect = UniversalDetector()

        def handle_encoding(file: str):
            detect.reset()
            with open(file, "rb") as text:
                while True:
                    chunk = text.read(1024)
                    if not chunk:
                        break
                    detect.feed(chunk)
                    if detect.done:
                        break
            detected = detect.close()["encoding"]
            if detected is None:
                msg = "failed to detect encoding; please specify with the `encoding` argument"
                raise RuntimeError(msg)
            return detected

    if os.path.splitext(paths[0])[1] == ".txt" and not sel:
        if collapse_lines:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    text.append(" ".join([line.rstrip() for line in texts]))
        else:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    lines = [line.rstrip() for line in texts]
                    text += lines
                    ids += [file] if len(lines) == 1 else [file + str(i + 1) for i in range(len(lines))]
            return pandas.DataFrame({"text": text, "ids": ids})
    elif collapse_lines:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            text.append(" ".join(temp[text_column]))
    else:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            text += temp[text_column].to_list()
            ids += (
                temp[id_column].to_list()
                if id_column is not None
                else [file] if len(temp) == 1 else [file + str(i + 1) for i in range(len(temp))]
            )
        return pandas.DataFrame({"text": text, "ids": ids})
    return text


def _get_writer(write_format: str):
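    """Return the pyarrow table writer matching the cache format."""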
    if write_format == "parquet":
        return pyarrow.parquet.write_table
    if write_format == "feather":
        return pyarrow.feather.write_feather