Coverage for src/receptiviti/manage_request.py: 80% (390 statements)
1"""Make requests to the API."""
3import hashlib
4import json
5import math
6import os
7import pickle
8import re
9import sys
10import urllib.parse
11import warnings
12from glob import glob
13from multiprocessing import Process, Queue, current_process
14from tempfile import TemporaryDirectory, gettempdir
15from time import perf_counter, sleep, time
16from typing import List, Union
18import numpy
19import pandas
20import pyarrow
21import pyarrow.compute
22import pyarrow.dataset
23import pyarrow.feather
24import pyarrow.parquet
25import requests
26from chardet.universaldetector import UniversalDetector
27from tqdm import tqdm
29from receptiviti.status import _resolve_request_def, status
31CACHE = gettempdir() + "/receptiviti_cache/"
32REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"


def _manage_request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type="txt",
    encoding: Union[str, None] = None,
    context="written",
    api_args: Union[dict, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    request_cache=True,
    cores=1,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache=os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet"),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
    to_norming=False,
) -> tuple[pandas.DataFrame, pandas.DataFrame | None, bool]:
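    """Resolve inputs, bundle texts, and submit them to the Receptiviti API.

    Returns a tuple of the full input data frame, the combined results
    (or None), and whether IDs were explicitly specified.
    """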
    if cores > 1 and current_process().name != "MainProcess":
        return (pandas.DataFrame(), None, False)
    start_time = perf_counter()

    if request_cache:
        if verbose:
            print(f"preparing request cache ({perf_counter() - start_time:.4f})")
        _manage_request_cache()

    # resolve credentials and check status
    full_url, url, key, secret = _resolve_request_def(url, key, secret, dotenv)
    url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", full_url)
    if url_parts:
        from_url = url_parts.groups()
        if not version and from_url[0] is not None:
            version = from_url[0]
        if not endpoint and from_url[1] is not None:
            endpoint = from_url[1]
    if to_norming:
        version = "v2"
        endpoint = "norming"
        request_cache = False
    else:
        if not version:
            version = os.getenv("RECEPTIVITI_VERSION", "v1")
        version = version.lower()
        if not version or not re.search("^v\\d+$", version):
            msg = f"invalid version: {version}"
            raise RuntimeError(msg)
        if not endpoint:
            endpoint_default = "framework" if version == "v1" else "analyze"
            endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default)
        endpoint = re.sub("^.*/", "", endpoint).lower()
        if not endpoint or re.search("[^a-z]", endpoint):
            msg = f"invalid endpoint: {endpoint}"
            raise RuntimeError(msg)
    api_status = status(url, key, secret, dotenv, verbose=False)
    if api_status is None or api_status.status_code != 200:
        msg = (
            "URL is not reachable"
            if api_status is None
            else f"API status failed: {api_status.status_code}: {api_status.reason}"
        )
        raise RuntimeError(msg)

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                id_column = "ids"
                text_column = "text"
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif os.path.isdir(os.path.dirname(text)):
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is not None:
            if text_column in text:
                text = text[text_column].to_list()
            else:
                msg = f"`text_column` ({text_column}) is not in `text`"
                raise IndexError(msg)
        else:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
    if isinstance(text, str):
        text = [text]
    text_is_path = all(isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text)
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            if id_column is None:
                ids = text["ids"].to_list()
            elif id_column in text:
                ids = text[id_column].to_list()
            if text_column is None:
                text_column = "text"
            text = text[text_column].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    data_subset = data[~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())]
    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    bundle_size = max(1, bundle_size)
    n_bundles = math.ceil(n_texts / min(1000, bundle_size))
    groups = data_subset.groupby(
        numpy.sort(numpy.tile(numpy.arange(n_bundles) + 1, bundle_size))[:n_texts],
        group_keys=False,
    )
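    # split any group that exceeds the byte limit into smaller bundles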
    bundles = []
    for _, group in groups:
        if sys.getsizeof(group) > bundle_byte_limit:
            start = current = end = 0
            for txt in group["text"]:
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = f"one of your texts is over the bundle size limit ({bundle_byte_limit / 1e6} MB)"
                    raise RuntimeError(msg)
                if (current + size) > bundle_byte_limit:
                    bundles.append(group[start:end])
                    start = end = end + 1
                    current = size
                else:
                    end += 1
                    current += size
            bundles.append(group[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    opts = {
        "url": (
            full_url
            if to_norming
            else (
                f"{url}/{version}/{endpoint}/bulk" if version == "v1" else f"{url}/{version}/{endpoint}/{context}"
            ).lower()
        ),
        "version": version,
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        "cache": "" if cache_overwrite else cache,
        "cache_format": cache_format,
        "to_norming": to_norming,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
    }
    if version != "v1" and api_args:
        opts["url"] += "?" + urllib.parse.urlencode(api_args)
256 opts["add_hash"] = hashlib.md5(
257 json.dumps(
258 {**opts["add"], "url": opts["url"], "key": key, "secret": secret},
259 separators=(",", ":"),
260 ).encode()
261 ).hexdigest()
262 if isinstance(progress_bar, str):
263 progress_bar = progress_bar == "True"
264 use_pb = (verbose and progress_bar) or progress_bar
265 parallel = n_bundles > 1 and cores > 1
266 if in_memory is None:
267 in_memory = not parallel
268 with TemporaryDirectory() as scratch_cache:
269 if not in_memory:
270 if verbose:
271 print(f"writing to scratch cache ({perf_counter() - start_time:.4f})")
273 def write_to_scratch(i: int, bundle: pandas.DataFrame):
274 temp = f"{scratch_cache}/{i}.json"
275 with open(temp, "wb") as scratch:
276 pickle.dump(bundle, scratch)
277 return temp
279 bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)]
280 if parallel:
281 if verbose:
282 print(f"requesting in parallel ({perf_counter() - start_time:.4f})")
283 waiter: "Queue[pandas.DataFrame]" = Queue()
284 queue: "Queue[tuple[int, pandas.DataFrame]]" = Queue()
285 manager = Process(
286 target=_queue_manager,
287 args=(queue, waiter, n_texts, n_bundles, use_pb, verbose),
288 )
289 manager.start()
290 nb = math.ceil(n_bundles / min(n_bundles, cores))
291 cores = math.ceil(n_bundles / nb)
292 procs = [
293 Process(
294 target=_process,
295 args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue),
296 )
297 for i in range(cores)
298 ]
299 for cl in procs:
300 cl.start()
301 res = waiter.get()
302 else:
303 if verbose:
304 print(f"requesting serially ({perf_counter() - start_time:.4f})")
305 pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
306 res = _process(bundles, opts, pb=pb)
307 if pb is not None:
308 pb.close()
309 if verbose:
310 print(f"done requesting ({perf_counter() - start_time:.4f})")
312 return (data, res, id_specified)


def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[pandas.DataFrame]",
    n_texts: int,
    n_bundles: int,
    use_pb=True,
    verbose=False,
):
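    """Collect bundle results from worker processes.

    Updates the progress bar as chunks arrive, and passes the concatenated
    results to the `waiter` queue once all bundles have been received.
    """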
    if use_pb:
        pb = tqdm(total=n_texts, leave=verbose)
    res: List[pandas.DataFrame] = []
    for size, chunk in iter(queue.get, None):
        if isinstance(chunk, pandas.DataFrame):
            if use_pb:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
        else:
            break
    waiter.put(pandas.concat(res, ignore_index=True, sort=False))


def _process(
    bundles: list,
    opts: dict,
    queue: Union["Queue[tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> pandas.DataFrame:
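    """Request results for each bundle of texts, drawing on and adding to the primary cache when enabled.

    Returns the concatenated results, and raises a RuntimeError if a bundle fails.
    """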
    reses: List[pandas.DataFrame] = []
    for bundle in bundles:
        if isinstance(bundle, str):
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            bundle["text"] = _readin(
                bundle["text"],
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            if opts["version"] == "v1":
                body.append({"content": text, "request_id": text_hash, **opts["add"]})
            else:
                body.append({"text": text, "request_id": text_hash})
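        # check the primary cache (a hive-partitioned Arrow dataset) for already-scored texts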
        cached: Union[pandas.DataFrame, None] = None
        if opts["cache"] and os.listdir(opts["cache"]):
            db = pyarrow.dataset.dataset(
                opts["cache"],
                partitioning=pyarrow.dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            if "text_hash" in db.schema.names:
                su = db.filter(pyarrow.compute.field("text_hash").isin(bundle["text_hash"]))
                if su.count_rows() > 0:
                    cached = su.to_table().to_pandas(split_blocks=True, self_destruct=True)
        res: Union[str, pandas.DataFrame] = "failed to retrieve results"
        json_body = json.dumps(body, separators=(",", ":"))
        bundle_hash = hashlib.md5(json_body.encode()).hexdigest()
        if cached is None or len(cached) < len(bundle):
            if cached is None or not len(cached):
                res = _prepare_results(json_body, bundle_hash, opts)
            else:
                fresh = ~pyarrow.compute.is_in(
                    bundle["text_hash"].to_list(), pyarrow.array(cached["text_hash"])
                ).to_pandas(split_blocks=True, self_destruct=True)
                json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                res = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
            if not isinstance(res, str):
                if cached is not None:
                    if res.ndim != cached.ndim or not all(cached.columns.isin(res.columns)):
                        json_body = json.dumps([body[i] for i, ck in enumerate(fresh) if ck], separators=(",", ":"))
                        cached = _prepare_results(json_body, hashlib.md5(json_body.encode()).hexdigest(), opts)
                    res = pandas.concat([res, cached])
                if opts["cache"]:
                    writer = _get_writer(opts["cache_format"])
                    for id_bin, d in res.groupby("bin"):
                        bin_dir = f"{opts['cache']}/bin={id_bin}"
                        os.makedirs(bin_dir, exist_ok=True)
                        writer(
                            pyarrow.Table.from_pandas(d, preserve_index=False),
                            f"{bin_dir}/{bundle_hash}-0.{opts['cache_format']}",
                        )
        else:
            res = cached
        if not isinstance(res, str):
            if "text_hash" in res:
                res = res.merge(bundle.loc[:, ["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (len(res), res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    return reses[0] if len(reses) == 1 else pandas.concat(reses, ignore_index=True, sort=False)


def _prepare_results(body: str, bundle_hash: str, opts: dict):
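    """Request results for a bundle body, and normalize the response into a DataFrame."""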
    raw_res = _request(
        body,
        opts["url"],
        opts["auth"],
        opts["retries"],
        REQUEST_CACHE + bundle_hash + ".json" if opts["request_cache"] else "",
        opts["to_norming"],
        opts["make_request"],
    )
    if isinstance(raw_res, str):
        return raw_res
    res = pandas.json_normalize(raw_res)
    if "request_id" in res:
        res.rename(columns={"request_id": "text_hash"}, inplace=True)
        res.drop(
            list({"response_id", "language", "version", "error"}.intersection(res.columns)),
            axis="columns",
            inplace=True,
        )
        res.insert(res.ndim, "bin", ["h" + h[0] for h in res["text_hash"]])
    return res


def _request(
    body: str,
    url: str,
    auth: requests.auth.HTTPBasicAuth,
    retries: int,
    cache="",
    to_norming=False,
    execute=True,
) -> Union[dict, str]:
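    """Send a single request to the API, returning cached results when available.

    Returns parsed results on success, or an error message string on failure.
    """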
    if not os.path.isfile(cache):
        if not execute:
            return "`make_request` is `False`, but there are texts with no cached results"
        if to_norming:
            res = requests.patch(url, body, auth=auth, timeout=9999)
        else:
            res = requests.post(url, body, auth=auth, timeout=9999)
        if cache and res.status_code == 200:
            with open(cache, "w", encoding="utf-8") as response:
                json.dump(res.json(), response)
    else:
        with open(cache, encoding="utf-8") as response:
            data = json.load(response)
        return data["results"] if "results" in data else data
    data = res.json()
    if res.status_code == 200:
        data = dict(data[0] if isinstance(data, list) else data)
        return data["results"] if "results" in data else data
    if os.path.isfile(cache):
        os.remove(cache)
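    # on error code 1420, wait for the period indicated in the response (taken to be milliseconds), then retry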
    if retries > 0 and "code" in data and data["code"] == 1420:
        cd = re.search(
            "[0-9]+(?:\\.[0-9]+)?",
            (res.json()["message"] if res.headers["Content-Type"] == "application/json" else res.text),
        )
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        return _request(body, url, auth, retries - 1, cache, to_norming)
    return f"request failed, and have no retries: {res.status_code}: {data['error'] if 'error' in data else res.reason}"


def _manage_request_cache():
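    """Create the request cache directory, and clear out cached responses older than a day."""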
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        refreshed = time()
        log_file = REQUEST_CACHE + "log.txt"
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                logged = log.readline()
                if isinstance(logged, list):
                    logged = logged[0]
                refreshed = float(logged)
        else:
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(time()))
        if time() - refreshed > 86400:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                os.remove(cached_request)
    except Exception as exc:
        warnings.warn(UserWarning(f"failed to manage request cache: {exc}"), stacklevel=2)


def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
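    """Read texts in from plain-text or CSV files, detecting encoding when none is specified.

    Returns a list of texts, or a DataFrame of texts and IDs when IDs can be derived.
    """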
    text = []
    ids = []
    sel = []
    if text_column is not None:
        sel.append(text_column)
    if id_column is not None:
        sel.append(id_column)
    enc = encoding
    predict_encoding = enc is None
    if predict_encoding:
        detect = UniversalDetector()

        def handle_encoding(file: str):
            detect.reset()
            with open(file, "rb") as text:
                while True:
                    chunk = text.read(1024)
                    if not chunk:
                        break
                    detect.feed(chunk)
                    if detect.done:
                        break
            detected = detect.close()["encoding"]
            if detected is None:
                msg = "failed to detect encoding; please specify with the `encoding` argument"
                raise RuntimeError(msg)
            return detected

    if os.path.splitext(paths[0])[1] == ".txt" and not sel:
        if collapse_lines:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    text.append(" ".join([line.rstrip() for line in texts]))
        else:
            for file in paths:
                if predict_encoding:
                    enc = handle_encoding(file)
                with open(file, encoding=enc, errors="ignore") as texts:
                    lines = [line.rstrip() for line in texts]
                    text += lines
                    ids += [file] if len(lines) == 1 else [file + str(i + 1) for i in range(len(lines))]
            return pandas.DataFrame({"text": text, "ids": ids})
    elif collapse_lines:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            text.append(" ".join(temp[text_column]))
    else:
        for file in paths:
            if predict_encoding:
                enc = handle_encoding(file)
            temp = pandas.read_csv(file, encoding=enc, usecols=sel)
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            text += temp[text_column].to_list()
            ids += (
                temp[id_column].to_list()
                if id_column is not None
                else [file] if len(temp) == 1 else [file + str(i + 1) for i in range(len(temp))]
            )
        return pandas.DataFrame({"text": text, "ids": ids})
    return text


def _get_writer(write_format: str):
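    """Return the pyarrow writer function for the configured cache format."""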
    if write_format == "parquet":
        return pyarrow.parquet.write_table
    if write_format == "feather":
        return pyarrow.feather.write_feather