Coverage for src/receptiviti/request.py: 87% (131 statements)
"""Make requests to the API."""

import os
import re
import shutil
from glob import glob
from math import ceil
from multiprocessing import current_process
from tempfile import gettempdir
from time import perf_counter, time
from typing import List, Union

import pandas
import pyarrow.dataset

from receptiviti.frameworks import frameworks as get_frameworks
from receptiviti.manage_request import _get_writer, _manage_request
from receptiviti.norming import norming
from receptiviti.readin_env import readin_env

CACHE = gettempdir() + "/receptiviti_cache/"
REQUEST_CACHE = gettempdir() + "/receptiviti_request_cache/"


def request(
    text: Union[str, List[str], pandas.DataFrame, None] = None,
    output: Union[str, None] = None,
    ids: Union[str, List[str], List[int], None] = None,
    text_column: Union[str, None] = None,
    id_column: Union[str, None] = None,
    files: Union[List[str], None] = None,
    directory: Union[str, None] = None,
    file_type: str = "txt",
    encoding: Union[str, None] = None,
    return_text=False,
    context="written",
    custom_context: Union[str, bool] = False,
    api_args: Union[dict, None] = None,
    frameworks: Union[str, List[str], None] = None,
    framework_prefix: Union[bool, None] = None,
    bundle_size=1000,
    bundle_byte_limit=75e5,
    collapse_lines=False,
    retry_limit=50,
    clear_cache=False,
    request_cache=True,
    cores=1,
    in_memory: Union[bool, None] = None,
    verbose=False,
    progress_bar: Union[str, bool] = os.getenv("RECEPTIVITI_PB", "True"),
    overwrite=False,
    make_request=True,
    text_as_paths=False,
    dotenv: Union[bool, str] = True,
    cache: Union[str, bool] = os.getenv("RECEPTIVITI_CACHE", ""),
    cache_overwrite=False,
    cache_format=os.getenv("RECEPTIVITI_CACHE_FORMAT", ""),
    key=os.getenv("RECEPTIVITI_KEY", ""),
    secret=os.getenv("RECEPTIVITI_SECRET", ""),
    url=os.getenv("RECEPTIVITI_URL", ""),
    version=os.getenv("RECEPTIVITI_VERSION", ""),
    endpoint=os.getenv("RECEPTIVITI_ENDPOINT", ""),
) -> Union[pandas.DataFrame, None]:
    """
    Send texts to be scored by the API.

    Args:
        text (str | list[str] | pandas.DataFrame): Text to be processed, as a string or vector of
            strings containing the text itself, or the path to a file from which to read in text.
            If a DataFrame, `text_column` is used to extract such a vector. A string may also
            represent a directory in which to search for files. To best ensure paths are not
            treated as texts, either set `text_as_paths` to `True`, or use `directory` to enter
            a directory path or `files` to enter a vector of file paths.
        output (str): Path to a file to write results to.
        ids (str | list[str | int]): Vector of IDs for each `text`, or a column name in `text`
            containing IDs.
        text_column (str): Column name in `text` containing text.
        id_column (str): Column name in `text` containing IDs.
        files (list[str]): Vector of file paths, as alternate entry to `text`.
        directory (str): A directory path to search for files in, as alternate entry to `text`.
        file_type (str): Extension of the file(s) to be read in from a directory (`txt` or `csv`).
        encoding (str | None): Encoding of file(s) to be read in; one of the
            [standard encodings](https://docs.python.org/3/library/codecs.html#standard-encodings).
            If this is `None` (default), encoding will be predicted for each file, but this can
            potentially fail, resulting in mis-encoded characters. For best (and fastest) results,
            specify encoding.
        return_text (bool): If `True`, will include a `text` column in the output with the
            original text.
        context (str): Name of the analysis context.
        custom_context (str | bool): Name of a custom context (as listed by `receptiviti.norming`),
            or `True` if `context` is the name of a custom context.
        api_args (dict): Additional arguments to include in the request.
        frameworks (str | list): One or more names of frameworks to request. Note that this
            changes the results from the API, so it will invalidate any cached results
            that do not use the same set of frameworks.
        framework_prefix (bool): If `False`, will drop the framework prefix from column names.
            If a single framework is selected, this defaults to `False`.
        bundle_size (int): Maximum number of texts per bundle.
        bundle_byte_limit (float): Maximum byte size of each bundle.
        collapse_lines (bool): If `True`, will treat files as containing single texts, and
            collapse multiple lines.
        retry_limit (int): Number of times to retry a failed request.
        clear_cache (bool): If `True`, will delete the `cache` before processing.
        request_cache (bool): If `False`, will not temporarily save raw requests for reuse
            within a day.
        cores (int): Number of CPU cores to use when processing multiple bundles.
        in_memory (bool | None): If `False`, will write bundles to disc, to be loaded when
            processed. Defaults to `False` when processing in parallel.
        verbose (bool): If `True`, will print status messages and preserve the progress bar.
        progress_bar (str | bool): If `False`, will not display a progress bar.
        overwrite (bool): If `True`, will overwrite an existing `output` file.
        make_request (bool): If `False`, will not send requests to the API, in which case
            results can only come from the cache.
        text_as_paths (bool): If `True`, will explicitly mark `text` as a list of file paths.
            Otherwise, this will be detected.
        dotenv (bool | str): Path to a .env file to read environment variables from. By default,
            will look for a file in the current directory or `~/Documents`.
            Passed to `readin_env` as `path`.
        cache (bool | str): Path to a cache directory, or `True` to use the default directory.
        cache_overwrite (bool): If `True`, will not check the cache for previously cached texts,
            but will still store results in the cache (unlike `cache=False`).
        cache_format (str): File format of the cache; one of the available Arrow formats
            (`parquet` or `feather`).
        key (str): Your API key.
        secret (str): Your API secret.
        url (str): The URL of the API; defaults to `https://api.receptiviti.com`.
        version (str): Version of the API; defaults to `v1`.
        endpoint (str): Endpoint of the API; defaults to `framework`.

    Returns:
        Scores associated with each input text.

    Examples:
        ```
        # score a single text
        single = receptiviti.request("a text to score")

        # score multiple texts, and write results to a file
        multi = receptiviti.request(["first text to score", "second text"], "filename.csv")

        # score texts in separate files
        ## defaults to look for .txt files
        file_results = receptiviti.request(directory = "./path/to/txt_folder")

        ## could be .csv
        file_results = receptiviti.request(
            directory = "./path/to/csv_folder",
            text_column = "text", file_type = "csv"
        )

        # score texts in a single file
        results = receptiviti.request("./path/to/file.csv", text_column = "text")
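
        # a sketch combining framework selection with caching;
        # the file and cache paths here are hypothetical
        cached_results = receptiviti.request(
            "./path/to/file.csv", text_column = "text",
            frameworks = "summary", cache = "./receptiviti_cache"
        )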
        ```

    Cache:
        If `cache` is specified, results for unique texts are saved in an Arrow database
        in the cache location (`os.getenv("RECEPTIVITI_CACHE")`), and are retrieved with
        subsequent requests. This ensures that the exact same texts are not re-sent to the API.
        This does, however, add some processing time and disc space usage.

        If `cache` is `True`, a default directory (`receptiviti_cache`) will be
        looked for in the system's temporary directory (`tempfile.gettempdir()`).

        The primary cache is checked when each bundle is processed, and existing results are
        loaded at that time. When processing many bundles in parallel, and many results have
        been cached, this can cause the system to freeze and potentially crash.
        To avoid this, limit the number of cores, or disable parallel processing.

        The `cache_format` argument (or the `RECEPTIVITI_CACHE_FORMAT` environment variable)
        can be used to adjust the format of the cache.

        You can use the cache independently with
        `pyarrow.dataset.dataset(os.getenv("RECEPTIVITI_CACHE"))`.
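        For example, a minimal sketch of reading the cache into pandas (assuming the default
        `parquet` format and an existing, populated cache):
        ```
        import os
        import pyarrow.dataset

        # the cache is partitioned by text bin (hive-style `bin=*` directories)
        cached = pyarrow.dataset.dataset(
            os.getenv("RECEPTIVITI_CACHE"), format="parquet", partitioning="hive"
        )
        scores = cached.to_table().to_pandas()
        ```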
        You can also set the `clear_cache` argument to `True` to clear the cache before it is
        used again, which may be useful if the cache has gotten big, or you know new results
        will be returned.

        Even if a cached result exists, it will be reprocessed if it does not have all of the
        variables of new results, but this depends on there being at least 1 uncached result. If,
        for instance, you add a framework to your account and want to reprocess a previously
        processed set of texts, you would need to first clear the cache.

        Either way, duplicated texts within the same call will only be sent once.

        The `request_cache` argument controls a more temporary cache of each bundle request,
        which is cleared after a day. You might want to set this to `False` if a new framework
        becomes available on your account and you want to re-process texts you processed
        recently.

        Another temporary cache is made when `in_memory` is `False`, which is the default when
        processing in parallel (when there is more than 1 bundle and `cores` is over 1). This is
        a temporary directory that contains a file for each unique bundle, which is read in as
        needed by the parallel workers.

    Parallelization:
        `text`s are split into bundles based on the `bundle_size` argument. Each bundle
        represents a single request to the API, which is why bundles are limited to 1,000 texts
        and a total size of 10 MB. For example, 2,500 texts with the default `bundle_size` of
        1,000 would be sent as three bundles. When there is more than one bundle and `cores` is
        greater than 1, bundles are processed by multiple cores.

        If you have texts spread across multiple files, they can be most efficiently processed
        in parallel if each file contains a single text (potentially collapsed from multiple
        lines). If files contain multiple texts (i.e., `collapse_lines=False`), then texts need
        to be read in before bundling in order to ensure bundles are under the length limit.

        If you are calling this function from a script, parallelization will involve rerunning
        that script in each process, so anything you don't want rerun should be placed within an
        `if __name__ == "__main__":` clause.
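        For example, a minimal sketch of a guarded script (the directory path is hypothetical):
        ```
        import receptiviti

        if __name__ == "__main__":
            # only the main process scores the texts; parallel workers
            # re-import this script without re-triggering the request
            results = receptiviti.request(
                directory = "./path/to/txt_folder", cores = 4
            )
            print(results.head())
        ```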
    """
    if cores > 1 and current_process().name != "MainProcess":
        return None
    if output is not None and os.path.isfile(output) and not overwrite:
        msg = "`output` file already exists; use `overwrite=True` to overwrite it"
        raise RuntimeError(msg)
    start_time = perf_counter()

    if dotenv:
        readin_env(dotenv if isinstance(dotenv, str) else ".")
        dotenv = False

    # check norming context
    if isinstance(custom_context, str):
        context = custom_context
        custom_context = True
    if context != "written":
        if verbose:
            print(f"retrieving norming contexts ({perf_counter() - start_time:.4f})")
        available_contexts: "list[str]" = norming(name_only=True, url=url, key=key, secret=secret, verbose=False)
        if ("custom/" + context if custom_context else context) not in available_contexts:
            msg = f"norming context {context} is not on record or is not completed"
            raise RuntimeError(msg)

    # check frameworks
    if frameworks and version and "2" in version:
        if not api_args:
            api_args = {}
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        api_args["frameworks"] = [f for f in frameworks if f != "summary"]
    if api_args and "frameworks" in api_args:
        arg_frameworks: "list[str]" = (
            api_args["frameworks"].split(",") if isinstance(api_args["frameworks"], str) else api_args["frameworks"]
        )
        available_frameworks = get_frameworks(url=url, key=key, secret=secret)
        for f in arg_frameworks:
            if f not in available_frameworks:
                msg = f"requested framework is not available to your account: {f}"
                raise RuntimeError(msg)
        if isinstance(api_args["frameworks"], list):
            api_args["frameworks"] = ",".join(api_args["frameworks"])

    if isinstance(cache, str) and cache:
        if clear_cache and os.path.exists(cache):
            shutil.rmtree(cache, ignore_errors=True)
        os.makedirs(cache, exist_ok=True)
        if not cache_format:
            cache_format = os.getenv("RECEPTIVITI_CACHE_FORMAT", "parquet")
        if cache_format not in ["parquet", "feather"]:
            msg = "`cache_format` must be `parquet` or `feather`"
            raise RuntimeError(msg)
    else:
        cache = ""

    data, res, id_specified = _manage_request(
        text=text,
        ids=ids,
        text_column=text_column,
        id_column=id_column,
        files=files,
        directory=directory,
        file_type=file_type,
        encoding=encoding,
        context=f"custom/{context}" if custom_context else context,
        api_args=api_args,
        bundle_size=bundle_size,
        bundle_byte_limit=bundle_byte_limit,
        collapse_lines=collapse_lines,
        retry_limit=retry_limit,
        request_cache=request_cache,
        cores=cores,
        in_memory=in_memory,
        verbose=verbose,
        progress_bar=progress_bar,
        make_request=make_request,
        text_as_paths=text_as_paths,
        dotenv=dotenv,
        cache=cache,
        cache_overwrite=cache_overwrite,
        cache_format=cache_format,
        key=key,
        secret=secret,
        url=url,
        version=version,
        endpoint=endpoint,
    )

    # finalize
    if res is None or not res.shape[0]:
        msg = "no results"
        raise RuntimeError(msg)
    if isinstance(cache, str):
        # consolidate any cache fragments written during this run
        writer = _get_writer(cache_format)
        schema = pyarrow.schema(
            (
                col,
                (
                    pyarrow.string()
                    if res[col].dtype == "O"
                    else (
                        pyarrow.int32()
                        if col in ["summary.word_count", "summary.sentence_count"]
                        else pyarrow.float32()
                    )
                ),
            )
            for col in res.columns
            if col not in ["id", "bin", *(api_args.keys() if api_args else [])]
        )
        for bin_dir in glob(cache + "/bin=*/"):
            _defragment_bin(bin_dir, cache_format, writer, schema)
    if verbose:
        print(f"preparing output ({perf_counter() - start_time:.4f})")
    data.set_index("id", inplace=True)
    res.set_index("id", inplace=True)
    if len(res) != len(data):
        # duplicate texts within the call were only sent once; copy their
        # results back to the input IDs that are missing from the response
        res = res.join(data["text"])
        data_absent = data.loc[list(set(data.index).difference(res.index))]
        data_absent = data_absent.loc[data_absent["text"].isin(res["text"])]
        if data.size:
            res = res.reset_index()
            res.set_index("text", inplace=True)
            data_dupes = res.loc[data_absent["text"]]
            data_dupes["id"] = data_absent.index.to_list()
            res = pandas.concat([res, data_dupes])
            res.reset_index(inplace=True, drop=True)
            res.set_index("id", inplace=True)
    res = res.join(data["text"], how="right")
    if not return_text:
        res.drop("text", axis=1, inplace=True)
    res = res.reset_index()

    if output is not None:
        if verbose:
            print(f"writing results to file: {output} ({perf_counter() - start_time:.4f})")
        res.to_csv(output, index=False)

    drops = ["custom", "bin"]
    if not id_specified:
        drops.append("id")
    res.drop(
        list({*drops}.intersection(res.columns)),
        axis="columns",
        inplace=True,
    )
    if frameworks is not None:
        if verbose:
            print(f"selecting frameworks ({perf_counter() - start_time:.4f})")
        if isinstance(frameworks, str):
            frameworks = [frameworks]
        if len(frameworks) == 1 and framework_prefix is None:
            framework_prefix = False
        select = []
        if id_specified:
            select.append("id")
        if return_text:
            select.append("text")
        select.append("text_hash")
        res = res.filter(regex=f"^(?:{'|'.join(select + frameworks)})(?:$|\\.)")
    if isinstance(framework_prefix, bool) and not framework_prefix:
        prefix_pattern = re.compile("^[^.]+\\.")
        res.columns = pandas.Index([prefix_pattern.sub("", col) for col in res.columns])

    if verbose:
        print(f"done ({perf_counter() - start_time:.4f})")

    return res


def _defragment_bin(bin_dir: str, write_format: str, writer, schema: pyarrow.Schema):
    """Consolidate the fragment files within a cache bin into one file per ~1e9 rows."""
    fragments = glob(f"{bin_dir}/*.{write_format}")
    if len(fragments) > 1:
        data = pyarrow.dataset.dataset(fragments, schema, format=write_format, exclude_invalid_files=True).to_table()
        nrows = data.num_rows
        n_chunks = max(1, ceil(nrows / 1e9))
        rows_per_chunk = max(1, ceil(nrows / n_chunks))
        time_id = str(ceil(time()))
        # step through row offsets (not chunk indices), so every chunk is
        # written before the original fragments are removed
        for offset in range(0, nrows, rows_per_chunk):
            writer(data[offset : (offset + rows_per_chunk)], f"{bin_dir}/part-{time_id}-{offset}.{write_format}")
        for fragment in fragments:
            os.unlink(fragment)