Coverage for src/receptiviti/request.py: 79%
451 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-15 16:41 -0700
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-15 16:41 -0700
1"""Make requests to the API."""
3import hashlib
4import json
5import math
6import os
7import pickle
8import re
9import shutil
10import sys
11from glob import glob
12from multiprocessing import Process, Queue, current_process
13from tempfile import TemporaryDirectory, gettempdir
14from time import perf_counter, sleep, time
15from typing import List, Union
17import numpy
18import pandas
19import pyarrow
20import requests
21from chardet.universaldetector import UniversalDetector
22from pyarrow import compute, dataset
23from tqdm import tqdm
25from receptiviti.readin_env import readin_env
26from receptiviti.status import status
# Default directory for the primary results cache: the `receptiviti_cache` directory in the
# system's temporary directory that `request`'s documentation describes for `cache=True`.
CACHE = gettempdir() + "/receptiviti_cache/"
# Directory holding raw per-bundle API responses (one `<md5>.json` file per request body),
# purged roughly once a day by `_manage_request_cache`.
REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"
def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text=False,
    api_args: Union[dict, None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    clear_cache=False,
    request_cache=True,
    cores=1,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite=False,
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> pandas.DataFrame:
    """
    Send texts to be scored by the API.

    Args:
        text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of
            strings containing the text itself, or the path to a file from which to read in text.
            If a DataFrame, `text_column` is used to extract such a vector. A string may also
            represent a directory in which to search for files. To best ensure paths are not
            treated as texts, either set `text_as_paths` to `True`, or use `directory` to enter
            a directory path, or `files` to enter a vector of file paths.
        output (str): Path to a file to write results to.
        ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
            containing IDs.
        text_column (str): Column name in `text` containing text.
        id_column (str): Column name in `text` containing IDs.
        files (list[str]): Vector of file paths, as alternate entry to `text`.
        directory (str): A directory path to search for files in, as alternate entry to `text`.
        file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
        encoding (str | None): Encoding of file(s) to be read in; one of the
            [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
            If this is `None` (default), encoding will be predicted for each file, but this can
            potentially fail, resulting in mis-encoded characters. For best (and fastest) results,
            specify encoding.
        return_text (bool): If `True`, will include a `text` column in the output with the
            original text.
        api_args (dict): Additional arguments to include in the request.
        frameworks (str | list): One or more names of frameworks to return.
        framework_prefix (bool): If `False`, will drop framework prefix from column names.
            If one framework is selected, will default to `False`.
        bundle_size (int): Maximum number of texts per bundle; values above 1000 are capped
            at 1000.
        bundle_byte_limit (float): Maximum byte size of each bundle.
        collapse_lines (bool): If `True`, will treat files as containing single texts, and
            collapse multiple lines.
        retry_limit (int): Number of times to retry a failed request.
        clear_cache (bool): If `True`, will delete the `cache` before processing.
        request_cache (bool): If `False`, will not temporarily save raw requests for reuse
            within a day.
        cores (int): Number of CPU cores to use when processing multiple bundles.
        in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when
            processed. Defaults to `False` when processing in parallel, and `True` otherwise.
        verbose (bool): If `True`, will print status messages and preserve the progress bar.
        progress_bar (str | bool): If `False`, will not display a progress bar.
        overwrite (bool): If `True`, will overwrite an existing `output` file.
        make_request (bool): If `False`, no requests are sent, so every text must already have
            a cached result; otherwise an error is raised.
        text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
            Otherwise, this will be detected.
        dotenv (bool | str): Path to a .env file to read environment variables from. By default,
            will look for a file in the current directory or `~/Documents`.
            Passed to `readin_env` as `path`.
        cache (bool | str): Path to a cache directory, or `True` to use the default directory.
        cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
            but will store results in the cache (unlike `cache = False`).
        cache_format (str): File format of the cache, of available Arrow formats.
        key (str): Your API key.
        secret (str): Your API secret.
        url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
        version (str): Version of the API; defaults to `v1`.
        endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
        Scores associated with each input text.

    Raises:
        RuntimeError: If `output` exists without `overwrite`, the API status check fails,
            inputs are invalid, a text exceeds `bundle_byte_limit`, or no results come back.
        IndexError: If `text_column` or `id_column` is not in a DataFrame `text`.

    Cache:
        If `cache` is specified, results for unique texts are saved in an Arrow database
        in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
        subsequent requests. This ensures that the exact same texts are not re-sent to the API.
        This does, however, add some processing time and disc space usage.

        If `cache` is `True`, a default directory (`receptiviti_cache`) will be
        looked for in the system's temporary directory (`tempfile.gettempdir()`).

        The primary cache is checked when each bundle is processed, and existing results are
        loaded at that time. When processing many bundles in parallel, and many results have
        been cached, this can cause the system to freeze and potentially crash.
        To avoid this, limit the number of cores, or disable parallel processing.

        The `cache_format` argument (or the `RECEPTIVITI_CACHE_FORMAT` environment variable) can
        be used to adjust the format of the cache.

        You can use the cache independently with
        `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.

        You can also set the `clear_cache` argument to `True` to clear the cache before it is used
        again, which may be useful if the cache has gotten big, or you know new results will be
        returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least 1 uncached result. If,
        for instance, you add a framework to your account and want to reprocess a previously
        processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request. This
        is cleared after a day. You might want to set this to `False` if a new framework becomes
        available on your account and you want to process a set of texts you processed recently.

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This is a
        temporary directory that contains a file for each unique bundle, which is read in as needed
        by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle represents
        a single request to the API, which is why they are limited to at most 1000 texts and a
        total size of `bundle_byte_limit` bytes (7.5 MB by default). When there is more than one
        bundle and `cores` is greater than 1, bundles are processed by multiple cores.

        If you have texts spread across multiple files, they can be most efficiently processed in
        parallel if each file contains a single text (potentially collapsed from multiple lines).
        If files contain multiple texts (i.e., `collapse_lines=False`), then texts need to be
        read in before bundling in order to ensure bundles are under the length limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be protected by
        a check that `__name__` equals `"__main__"`
        (placed within an `if __name__ == "__main__":` clause).
    """
    # worker processes may re-run the calling script (see "Parallelization");
    # only the main process is allowed to start a request
    if cores > 1 and current_process().name != "MainProcess":
        return
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if request_cache:
        if verbose:
            print(f"preparing request cache ({perf_counter() - start_time:.4f})")
        _manage_request_cache()

    # resolve credentials and check status
    if dotenv:
        readin_env("." if isinstance(dotenv, bool) else dotenv)
    if not url:
        url = os.getenv("RECEPTIVITI_URL", "https://api.receptiviti.com")
    url_parts = re.search("/([Vv]\\d+)/?([^/]+)?", url)
    if url_parts:
        # a version (and endpoint) embedded in the URL fill in unspecified arguments
        from_url = url_parts.groups()
        if not version and from_url[0] is not None:
            version = from_url[0]
        if not endpoint and from_url[1] is not None:
            endpoint = from_url[1]
    url = ("https://" if re.match("http", url, re.I) is None else "") + re.sub(
        "/+[Vv]\\d+(?:/.*)?$|/+$", "", url
    )
    if not key:
        key = os.getenv("RECEPTIVITI_KEY", "")
    if not secret:
        secret = os.getenv("RECEPTIVITI_SECRET", "")
    if not version:
        version = os.getenv("RECEPTIVITI_VERSION", "v1")
    if not endpoint:
        endpoint_default = "framework" if version.lower() == "v1" else "taxonomies"
        endpoint = os.getenv("RECEPTIVITI_ENDPOINT", endpoint_default)
    api_status = status(url, key, secret, dotenv, verbose=False)
    if not api_status or api_status.status_code != 200:
        msg = (
            f"API status failed: {api_status.status_code}: {api_status.reason}"
            if api_status
            else "URL is not reachable"
        )
        raise RuntimeError(msg)

    # resolve text and ids
    text_as_dir = False
    if text is None:
        if directory is not None:
            text = directory
            text_as_dir = True
        elif files is not None:
            text_as_paths = True
            text = files
        else:
            msg = "enter text as the first argument, or use the `files` or `directory` arguments"
            raise RuntimeError(msg)
    if isinstance(text, str) and (text_as_dir or text_as_paths or len(text) < 260):
        if not text_as_dir and os.path.isfile(text):
            if verbose:
                print(f"reading in texts from a file ({perf_counter() - start_time:.4f})")
            text = _readin([text], text_column, id_column, collapse_lines, encoding)
            if isinstance(text, pandas.DataFrame):
                # `_readin` always returns its texts and ids as "text" and "ids"
                id_column = "ids"
                text_column = "text"
            # the file's contents replace the path, so `text` no longer holds paths
            text_as_paths = False
        elif os.path.isdir(text):
            text = glob(f"{text}/*{file_type}")
            text_as_paths = True
        elif text_as_dir or os.path.isdir(os.path.dirname(text)):
            # never send a missing directory's path to the API as if it were text
            msg = f"`text` appears to point to a directory, but it does not exist: {text}"
            raise RuntimeError(msg)
    if isinstance(text, pandas.DataFrame):
        if id_column is not None:
            if id_column in text:
                ids = text[id_column].to_list()
            else:
                msg = f"`id_column` ({id_column}) is not in `text`"
                raise IndexError(msg)
        if text_column is None:
            msg = "`text` is a DataFrame, but no `text_column` is specified"
            raise RuntimeError(msg)
        if text_column not in text:
            msg = f"`text_column` ({text_column}) is not in `text`"
            raise IndexError(msg)
        text = text[text_column].to_list()
    if isinstance(text, str):
        text = [text]
    text_is_path = all(
        isinstance(t, str) and (text_as_paths or len(t) < 260) and os.path.isfile(t) for t in text
    )
    if text_as_paths and not text_is_path:
        msg = "`text` treated as a list of files, but not all of the entries exist"
        raise RuntimeError(msg)
    if text_is_path and not collapse_lines:
        # files may contain several texts each, so they are read in now to bundle by size
        ids = text
        text = _readin(text, text_column, id_column, collapse_lines, encoding)
        if isinstance(text, pandas.DataFrame):
            # `_readin` has already moved `text_column` / `id_column` into "text" / "ids"
            ids = text["ids"].to_list()
            text = text["text"].to_list()
        text_is_path = False
    if ids is None and text_is_path:
        ids = text

    id_specified = ids is not None
    if ids is None:
        ids = numpy.arange(1, len(text) + 1).tolist()
    elif len(ids) != len(text):
        msg = "`ids` is not the same length as `text`"
        raise RuntimeError(msg)
    original_ids = set(ids)
    if len(ids) != len(original_ids):
        msg = "`ids` contains duplicates"
        raise RuntimeError(msg)

    # prepare bundles
    if verbose:
        print(f"preparing text ({perf_counter() - start_time:.4f})")
    data = pandas.DataFrame({"text": text, "id": ids})
    n_original = len(data)
    data_subset = data[
        ~(data.duplicated(subset=["text"]) | (data["text"] == "") | data["text"].isna())
    ]
    n_texts = len(data_subset)
    if not n_texts:
        msg = "no valid texts to process"
        raise RuntimeError(msg)
    # a single request holds at most 1000 texts
    bundle_size = int(min(1000, max(1, bundle_size)))
    groups = data_subset.groupby(numpy.arange(n_texts) // bundle_size, group_keys=False)
    bundles = []
    for _, group in groups:
        # paths say nothing about their files' sizes, so those are always measured
        if text_is_path or sys.getsizeof(group) > bundle_byte_limit:
            start = 0
            current = 0
            for position, txt in enumerate(group["text"]):
                size = os.stat(txt).st_size if text_is_path else sys.getsizeof(txt)
                if size > bundle_byte_limit:
                    msg = (
                        "one of your texts is over the bundle size"
                        f" limit ({bundle_byte_limit / 1e6} MB)"
                    )
                    raise RuntimeError(msg)
                if current + size > bundle_byte_limit:
                    # close the current bundle; this text starts the next one
                    bundles.append(group.iloc[start:position])
                    start = position
                    current = 0
                current += size
            bundles.append(group.iloc[start:])
        else:
            bundles.append(group)
    n_bundles = len(bundles)
    if verbose:
        print(
            f"prepared {n_texts} unique text{'s' if n_texts > 1 else ''} in "
            f"{n_bundles} {'bundles' if n_bundles > 1 else 'bundle'}",
            f"({perf_counter() - start_time:.4f})",
        )

    # process bundles
    if cache is True:
        # `True` selects the default cache directory
        cache = CACHE
    if isinstance(cache, str):
        if cache:
            if clear_cache and os.path.exists(cache):
                shutil.rmtree(cache, True)
            os.makedirs(cache, exist_ok=True)
            if not cache_format:
                cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        else:
            cache = False
    opts = {
        "url": f"{url}/{version}/{endpoint}/bulk".lower(),
        "auth": requests.auth.HTTPBasicAuth(key, secret),
        "retries": retry_limit,
        "add": {} if api_args is None else api_args,
        "request_cache": request_cache,
        # an empty path disables cache lookups (results may still be written afterward)
        "cache": cache if isinstance(cache, str) and not cache_overwrite else "",
        "cache_format": cache_format,
        "make_request": make_request,
        "text_is_path": text_is_path,
        "text_column": text_column,
        "id_column": id_column,
        "collapse_lines": collapse_lines,
        "encoding": encoding,
    }
    # text hashes depend on the request settings, so cached results are not mixed across them
    opts["add_hash"] = hashlib.md5(
        json.dumps(
            {**opts["add"], "url": opts["url"], "key": key, "secret": secret},
            separators=(",", ":"),
        ).encode()
    ).hexdigest()
    if isinstance(progress_bar, str):
        progress_bar = progress_bar == "True"
    use_pb = bool(progress_bar)
    parallel = n_bundles > 1 and cores > 1
    if in_memory is None:
        in_memory = not parallel
    with TemporaryDirectory() as scratch_cache:
        if not in_memory:
            if verbose:
                print(f"writing to scratch cache ({perf_counter() - start_time:.4f})")

            def write_to_scratch(i: int, bundle: pandas.DataFrame):
                temp = f"{scratch_cache}/{i}.json"
                with open(temp, "wb") as scratch:
                    pickle.dump(bundle, scratch)
                return temp

            bundles = [write_to_scratch(i, b) for i, b in enumerate(bundles)]
        if parallel:
            if verbose:
                print(f"requesting in parallel ({perf_counter() - start_time:.4f})")
            waiter: "Queue[pandas.DataFrame]" = Queue()
            queue: "Queue[tuple[int, pandas.DataFrame]]" = Queue()
            manager = Process(
                target=_queue_manager,
                args=(queue, waiter, n_texts, n_bundles, use_pb, verbose),
            )
            manager.start()
            nb = math.ceil(n_bundles / min(n_bundles, cores))
            cores = math.ceil(n_bundles / nb)
            procs = [
                Process(
                    target=_process,
                    args=(bundles[(i * nb) : min(n_bundles, (i + 1) * nb)], opts, queue),
                )
                for i in range(cores)
            ]
            for cl in procs:
                cl.start()
            res = waiter.get()
        else:
            if verbose:
                print(f"requesting serially ({perf_counter() - start_time:.4f})")
            pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
            res = _process(bundles, opts, pb=pb)
            if pb is not None:
                pb.close()
        if verbose:
            print(f"done requesting ({perf_counter() - start_time:.4f})")

    # finalize
    if not res.shape[0]:
        msg = "no results"
        raise RuntimeError(msg)
    if isinstance(cache, str):
        _update_cache(res, cache, cache_format, verbose, start_time, list(opts["add"]))
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != n_original:
        # duplicated texts were sent once; copy their results to every id that shares them
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        res = res.reset_index()
        res.set_index("text", inplace=True)
        data_dupes = res.loc[data_absent["text"]].copy()
        data_dupes["id"] = data_absent.index.to_list()
        res = pandas.concat([res, data_dupes])
        res.reset_index(inplace=True, drop=True)
        res.set_index("id", inplace=True)
    # align with the original input; skipped (empty / missing) texts become empty rows
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
    if isinstance(framework_prefix, bool) and not framework_prefix:
        prefix_pattern = re.compile("^[^.]+\\.")
        res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res
def _queue_manager(
    queue: "Queue[tuple[int, Union[pandas.DataFrame, None]]]",
    waiter: "Queue[pandas.DataFrame]",
    n_texts: int,
    n_bundles: int,
    use_pb=True,
    verbose=False,
):
    """Collect bundle results from worker processes and return them as one DataFrame.

    Reads `(size, chunk)` tuples from `queue` until `n_bundles` DataFrames have arrived, or
    until a non-DataFrame chunk (a failed bundle) is received. The collected chunks are
    concatenated and put on `waiter`. An empty DataFrame is put when nothing was collected,
    so the waiting process never blocks forever.

    Args:
        queue: Queue of `(number of texts, results or None)` tuples.
        waiter: Queue on which the combined results are put.
        n_texts: Total number of texts, used as the progress bar's total.
        n_bundles: Number of result chunks to wait for.
        use_pb: If `True`, display a progress bar.
        verbose: If `True`, leave the progress bar displayed when done.
    """
    pb = tqdm(total=n_texts, leave=verbose) if use_pb else None
    res: List[pandas.DataFrame] = []
    try:
        for size, chunk in iter(queue.get, None):
            if not isinstance(chunk, pandas.DataFrame):
                # a worker reported a failed bundle; stop collecting
                break
            if pb is not None:
                pb.update(size)
            res.append(chunk)
            if len(res) >= n_bundles:
                break
    finally:
        if pb is not None:
            pb.close()
    # `pandas.concat([])` raises, which would leave the waiting process hanging
    waiter.put(pandas.concat(res, ignore_index=True, sort=False) if res else pandas.DataFrame())
def _process(
    bundles: list,
    opts: dict,
    queue: Union["Queue[tuple[int, Union[pandas.DataFrame, None]]]", None] = None,
    pb: Union[tqdm, None] = None,
) -> pandas.DataFrame:
    """Score a list of bundles, using cached results where available.

    Args:
        bundles: DataFrames with `text` and `id` columns, or paths to pickled DataFrames
            (as written to the scratch cache by `request`).
        opts: Request options assembled by `request`.
        queue: If given (parallel mode), each bundle's result is put on it as
            `(number of rows, results)`, or `(0, None)` on failure.
        pb: Progress bar to update after each bundle (serial mode).

    Returns:
        Results for all bundles, with `text_hash` and `id` columns.

    Raises:
        RuntimeError: If a bundle could not be scored (carries the failure message).
    """
    reses: List[pandas.DataFrame] = []
    # request options and API arguments do not describe scores, so ignore them when comparing
    ignore = {"id", *opts["add"]}
    for bundle in bundles:
        if isinstance(bundle, str):
            with open(bundle, "rb") as scratch:
                bundle = pickle.load(scratch)
        body = []
        bundle.insert(0, "text_hash", "")
        if opts["text_is_path"]:
            # pass a plain list: the bundle keeps its original index, which need not contain 0
            bundle["text"] = _readin(
                bundle["text"].to_list(),
                opts["text_column"],
                opts["id_column"],
                opts["collapse_lines"],
                opts["encoding"],
            )
        for i, text in enumerate(bundle["text"]):
            text_hash = hashlib.md5((opts["add_hash"] + text).encode()).hexdigest()
            bundle.iat[i, 0] = text_hash
            body.append({"content": text, "request_id": text_hash, **opts["add"]})
        cached = None
        if opts["cache"] and os.path.isdir(opts["cache"] + "/bin=h"):
            db = dataset.dataset(
                opts["cache"],
                partitioning=dataset.partitioning(
                    pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
                ),
                format=opts["cache_format"],
            )
            if "text_hash" in db.schema.names:
                su = db.filter(compute.field("text_hash").isin(bundle["text_hash"]))
                if su.count_rows() > 0:
                    cached = su.to_table().to_pandas(split_blocks=True, self_destruct=True)
                    # repeated cache entries would inflate the count of cached texts
                    cached = cached.drop_duplicates(subset=["text_hash"])
        res = "failed to retrieve results"
        if cached is None or len(cached) < len(bundle):
            if cached is None or not len(cached):
                res = _prepare_results(body, opts)
            else:
                fresh = ~compute.is_in(
                    bundle["text_hash"].to_list(), pyarrow.array(cached["text_hash"])
                ).to_pandas(split_blocks=True, self_destruct=True)
                res = _prepare_results([body[i] for i, ck in enumerate(fresh) if ck], opts)
                if not isinstance(res, str):
                    fresh_vars = set(res.columns) - ignore
                    cached_vars = set(cached.columns) - ignore
                    if fresh_vars != cached_vars:
                        # cached results do not have the same variables as new ones: re-request
                        cached = _prepare_results(
                            [body[i] for i, ck in enumerate(fresh) if not ck], opts
                        )
                    res = cached if isinstance(cached, str) else pandas.concat([res, cached])
        else:
            res = cached
        if not isinstance(res, str):
            res = res.merge(bundle.loc[:, ["text_hash", "id"]], on="text_hash")
            reses.append(res)
        if queue is not None:
            queue.put((0, None) if isinstance(res, str) else (len(res), res))
        elif pb is not None:
            pb.update(len(bundle))
        if isinstance(res, str):
            raise RuntimeError(res)
    if not reses:
        return pandas.DataFrame()
    return reses[0] if len(reses) == 1 else pandas.concat(reses, ignore_index=True, sort=False)
def _prepare_results(body: list, opts: dict):
    """Send one bundle of texts and reshape the API response into a DataFrame.

    Args:
        body: Request entries (dicts with `content`, `request_id`, and any extra arguments).
        opts: Request options assembled by `request` (URL, auth, retries, caching flags).

    Returns:
        A DataFrame with a `text_hash` column, the returned variables, and a trailing `bin`
        column (`"h"` followed by the first character of the hash). If the request fails,
        the error message string from `_request` is returned instead.
    """
    payload = json.dumps(body, separators=(",", ":"))
    cache_file = ""
    if opts["request_cache"]:
        # raw responses are cached under the hash of the exact request body
        digest = hashlib.md5(payload.encode()).hexdigest()
        cache_file = f"{REQUEST_CACHE}{digest}.json"
    raw = _request(
        payload,
        opts["url"],
        opts["auth"],
        opts["retries"],
        cache_file,
        opts["make_request"],
    )
    if isinstance(raw, str):
        return raw
    frame = pandas.json_normalize(raw).rename(columns={"request_id": "text_hash"})
    if "text_hash" not in frame:
        # no request ids came back; fall back to the order of `body`
        frame.insert(0, "text_hash", [entry["request_id"] for entry in body])
    unwanted = [
        column
        for column in ("response_id", "language", "version", "error")
        if column in frame.columns
    ]
    frame = frame.drop(columns=unwanted)
    frame.insert(frame.shape[1], "bin", [f"h{text_hash[0]}" for text_hash in frame["text_hash"]])
    return frame
def _unwrap_results(data: Union[dict, list]) -> Union[dict, list]:
    """Return the results portion of a raw API response.

    A list response is reduced to its first element, and a `results` entry is returned
    when present; otherwise the (dict) response itself is returned.
    """
    data = dict(data[0] if isinstance(data, list) else data)
    return data["results"] if "results" in data else data


def _request(
    body: str,
    url: str,
    auth: "requests.auth.HTTPBasicAuth",
    retries: int,
    cache="",
    execute=True,
) -> Union[dict, str]:
    """POST a serialized bundle to the API, reusing or saving a cached raw response.

    Args:
        body: JSON-encoded request body.
        url: Full endpoint URL.
        auth: Basic-auth credentials.
        retries: Number of remaining retries for non-200 responses.
        cache: Path of a request-cache file; an empty string disables the request cache.
        execute: If `False`, only cached responses are used.

    Returns:
        The unwrapped results, or an error message string if no results could be obtained.
    """
    if os.path.isfile(cache):
        with open(cache, encoding="utf-8") as response:
            # unwrap exactly as a live response would be
            return _unwrap_results(json.load(response))
    if not execute:
        return "`make_request` is `False`, but there are texts with no cached results"
    res = requests.post(url, body, auth=auth, timeout=9999)
    if res.status_code == 200:
        data = res.json()
        if cache:
            with open(cache, "w", encoding="utf-8") as response:
                json.dump(data, response)
        return _unwrap_results(data)
    if os.path.isfile(cache):
        os.remove(cache)
    if retries > 0:
        # the error message may say how long to wait (read as milliseconds)
        message = res.text
        if res.headers.get("Content-Type", "").startswith("application/json"):
            try:
                payload = res.json()
            except ValueError:
                payload = None
            if isinstance(payload, dict) and "message" in payload:
                message = str(payload["message"])
        cd = re.search("[0-9]+(?:\\.[0-9]+)?", message)
        sleep(1 if cd is None else float(cd[0]) / 1e3)
        return _request(body, url, auth, retries - 1, cache)
    return f"request failed, and have no retries: {res.status_code}: {res.reason}"
def _update_cache(
    res: pandas.DataFrame,
    cache: str,
    cache_format: str,
    verbose: bool,
    start_time: float,
    add_names: list,
):
    """Write results that are not yet cached to the primary Arrow cache.

    The cache is a hive-partitioned dataset (`bin=<h + first hash character>`). A `bin=h`
    partition holding a placeholder row marks an initialized cache. If `res` has columns
    the cache lacks, the cache is cleared and re-initialized.

    Args:
        res: Results from `_process`; must include a `text_hash` column.
        cache: Path to the cache directory.
        cache_format: Arrow dataset format used to read and write the cache.
        verbose: If `True`, print status messages.
        start_time: `perf_counter()` value used to time status messages.
        add_names: Names of extra request arguments; these (and `id`) are not cached.
    """
    import uuid

    part: pyarrow.Partitioning = dataset.partitioning(
        pyarrow.schema([pyarrow.field("bin", pyarrow.string())]), flavor="hive"
    )
    exclude = {"id", *add_names}

    def write(frame: pandas.DataFrame):
        # Each write gets its own file names. With the default template ("part-{i}"),
        # `overwrite_or_ignore` would replace earlier files in the same `bin=` partition
        # and lose previously cached results.
        dataset.write_dataset(
            pyarrow.Table.from_pandas(frame),
            cache,
            partitioning=part,
            format=cache_format,
            existing_data_behavior="overwrite_or_ignore",
            basename_template=f"part-{uuid.uuid4().hex}-{{i}}.{cache_format}",
        )

    def initialize_cache():
        # one placeholder row (empty hash) in the `bin=h` partition fixes the cache's schema
        initial = res.iloc[[0]].drop(
            list(exclude.intersection(res.columns)),
            axis="columns",
        )
        initial["text_hash"] = ""
        initial["bin"] = "h"
        # non-object columns other than the count columns get a fractional placeholder,
        # presumably so they are stored as floating point
        initial.loc[
            :,
            ~initial.columns.isin(["summary.word_count", "summary.sentence_count"])
            & (initial.dtypes != object).to_list(),
        ] = 0.1
        write(initial)

    if not os.path.isdir(cache + "/bin=h"):
        if verbose:
            print(f"initializing cache ({perf_counter() - start_time:.4f})")
        initialize_cache()
    db = dataset.dataset(cache, partitioning=part, format=cache_format)
    if any(name not in exclude and name not in db.schema.names for name in res.columns.to_list()):
        if verbose:
            print(
                "clearing cache since it contains columns not in new results",
                f"({perf_counter() - start_time:.4f})",
            )
        shutil.rmtree(cache, True)
        initialize_cache()
        db = dataset.dataset(cache, partitioning=part, format=cache_format)
    fresh = res[~res.duplicated(subset=["text_hash"])]
    su = db.filter(compute.field("text_hash").isin(fresh["text_hash"]))
    if su.count_rows() > 0:
        uncached = ~compute.is_in(
            fresh["text_hash"].to_list(),
            su.scanner(columns=["text_hash"]).to_table()["text_hash"],
        ).to_pandas(split_blocks=True, self_destruct=True)
        if not any(uncached):
            return
        fresh = fresh[uncached.to_list()]
    n_new = len(fresh)
    if n_new:
        if verbose:
            print(
                f"adding {n_new} result{'' if n_new == 1 else 's'}",
                f"to cache ({perf_counter() - start_time:.4f})",
            )
        write(
            fresh.drop(
                list(exclude.intersection(fresh.columns)),
                axis="columns",
            )
        )
def _manage_request_cache():
    """Create the request cache directory and purge its responses once a day.

    `log.txt` in `REQUEST_CACHE` holds the time (seconds since the epoch) of the last purge.
    When that is more than a day old, every cached `*.json` response is deleted and the
    timestamp is reset.

    Raises:
        RuntimeWarning: If the log cannot be read or written, or a cached file cannot be removed.
    """
    os.makedirs(REQUEST_CACHE, exist_ok=True)
    try:
        log_file = REQUEST_CACHE + "log.txt"
        now = time()
        if os.path.exists(log_file):
            with open(log_file, encoding="utf-8") as log:
                refreshed = float(log.readline())
        else:
            refreshed = now
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(now))
        if now - refreshed > 86400:
            for cached_request in glob(REQUEST_CACHE + "*.json"):
                os.remove(cached_request)
            # restart the one-day window; otherwise every later call would purge again
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(str(now))
    except Exception as exc:
        msg = "failed to manage request cache"
        raise RuntimeWarning(msg) from exc
def _readin(
    paths: List[str],
    text_column: Union[str, None],
    id_column: Union[str, None],
    collapse_lines: bool,
    encoding: Union[str, None],
) -> Union[List[str], pandas.DataFrame]:
    """Read texts in from files.

    Files are read as plain text when the first path ends in `.txt` (any case) and no
    columns are selected; otherwise they are read with `pandas.read_csv`.

    Args:
        paths: File paths; any sequence, including a pandas Series with any index.
        text_column: Column containing texts (CSV files).
        id_column: Column containing IDs (CSV files).
        collapse_lines: If `True`, each file becomes a single text (joined with spaces).
        encoding: File encoding; if `None`, it is detected for each file.

    Returns:
        With `collapse_lines`, a list with one text per file. Otherwise a DataFrame with
        `text` and `ids` columns. IDs come from `id_column`, or are the file path, with the
        row number appended when a file holds several texts.

    Raises:
        IndexError: If `text_column` is missing from a CSV file.
    """
    # positional access below must not depend on a Series' index labels
    paths = list(paths)
    text: List[str] = []
    ids: list = []
    if not paths:
        return text if collapse_lines else pandas.DataFrame({"text": text, "ids": ids})
    sel = [column for column in (text_column, id_column) if column is not None]
    predict_encoding = encoding is None
    if predict_encoding:
        detect = UniversalDetector()

    def file_encoding(file: str):
        if not predict_encoding:
            return encoding
        detect.reset()
        with open(file, "rb") as raw:
            detect.feed(raw.read())
        return detect.close()["encoding"]

    def file_ids(file: str, n: int) -> list:
        # a lone text is identified by its file; several get a 1-based row suffix
        return [file] if n == 1 else [file + str(i + 1) for i in range(n)]

    if os.path.splitext(paths[0])[1].lower() == ".txt" and not sel:
        for file in paths:
            with open(file, encoding=file_encoding(file), errors="ignore") as texts:
                lines = [line.rstrip() for line in texts]
            if collapse_lines:
                text.append(" ".join(lines))
            else:
                text += lines
                ids += file_ids(file, len(lines))
    else:
        for file in paths:
            temp = pandas.read_csv(file, encoding=file_encoding(file), usecols=sel)
            if collapse_lines:
                text.append(" ".join(temp[text_column]))
                continue
            if text_column not in temp:
                msg = f"`text_column` ({text_column}) was not found in all files"
                raise IndexError(msg)
            text += temp[text_column].to_list()
            ids += (
                temp[id_column].to_list() if id_column is not None else file_ids(file, len(temp))
            )
    return text if collapse_lines else pandas.DataFrame({"text": text, "ids": ids})