Coverage for gramex\ml.py : 59%

import os
import six
import json
import inspect
import threading
import joblib
import pandas as pd
from tornado.gen import coroutine, Return, sleep
from tornado.httpclient import AsyncHTTPClient
from gramex.config import locate, app_log, merge, variables

# Expose joblib.load via gramex.ml
load = joblib.load      # noqa
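
# A minimal usage sketch (the path below is hypothetical): anything saved via
# joblib.dump() or Classifier.save() can be loaded back through this alias.
#
#   import gramex.ml
#   model = gramex.ml.load('model.pkl')     # returns the un-pickled object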


class Classifier(object):
    '''
    :arg data DataFrame: data to train / re-train the model with
    :arg model_class str: model class to use (default: ``sklearn.naive_bayes.BernoulliNB``)
    :arg model_kwargs dict: kwargs to pass to model class constructor (defaults: ``{}``)
    :arg output str: output column name (default: last column in training data)
    :arg input list: input column names (default: all columns except ``output``)
    :arg labels list: list of possible output values (default: unique ``output`` in training)
    '''

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
        self.model_class = kwargs.get('model_class', 'sklearn.naive_bayes.BernoulliNB')
        self.trained = False  # Boolean Flag

    def __str__(self):
        return repr(vars(self))

    def update_params(self, params):
        model_keys = ('model_class', 'url', 'input', 'output', 'trained', 'query', 'model_kwargs')
        model_params = {k: v[0] if isinstance(v, list) and k != 'input' else v
                        for k, v in params.items() if k in model_keys}
        if model_params:
            self.trained = params.get('trained', False)
            vars(self).update(model_params)
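
    # Usage sketch (hypothetical values): update_params() is typically fed URL query
    # parameters, where every value arrives as a list. Scalar keys take the first
    # element; 'input' keeps the whole list.
    #
    #   clf = Classifier()
    #   clf.update_params({'model_class': ['sklearn.linear_model.SGDClassifier'],
    #                      'input': ['sepal_length', 'sepal_width'],
    #                      'output': ['species']})
    #   # clf.model_class == 'sklearn.linear_model.SGDClassifier'
    #   # clf.input == ['sepal_length', 'sepal_width'], clf.output == 'species'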

    def train(self, data):
        '''
        :arg data DataFrame: data to train / re-train the model with
        :arg model_class str: model class to use (default: ``sklearn.naive_bayes.BernoulliNB``)
        :arg model_kwargs dict: kwargs to pass to model class constructor (defaults: ``{}``)
        :arg output str: output column name (default: last column in training data)
        :arg input list: input column names (default: all columns except ``output``)
        :arg labels list: list of possible output values (default: unique ``output`` in training)

        Notes:
        - If model has already been trained, extend the model. Else create it
        '''
        self.output = vars(self).get('output', data.columns[-1])
        self.input = vars(self).get('input', list(data.columns[:-1]))
        self.model_kwargs = vars(self).get('model_kwargs', {})
        self.labels = vars(self).get('labels', None)
        # If model_kwargs have changed since we trained last, re-train model.
        if not self.trained and hasattr(self, 'model'):
            vars(self).pop('model')
        if not hasattr(self, 'model'):
            # Split it into input (x) and output (y)
            x, y = data[self.input], data[self.output]
            # Transform the data
            from sklearn.preprocessing import StandardScaler
            self.scaler = StandardScaler()
            self.scaler.fit(x)
            # Train the classifier. Partially, if possible
            try:
                clf = locate(self.model_class)(**self.model_kwargs)
            except TypeError:
                raise ValueError('{0} is not a correct model class'.format(self.model_class))
            if self.labels and hasattr(clf, 'partial_fit'):
                try:
                    clf.partial_fit(self.scaler.transform(x),
                                    y, classes=self.labels)
                except AttributeError:
                    raise ValueError('{0} does not support partial fit'.format(self.model_class))
            else:
                clf.fit(self.scaler.transform(x), y)
            self.model = clf

        # Extend the model
        else:
            x, y = data[self.input], data[self.output]
            classes = set(self.model.classes_)
            classes |= set(y)
            self.model.partial_fit(self.scaler.transform(x), y)

    def predict(self, data):
        '''
        Return a Series that has the results of the classification of data
        '''
        # Convert list of lists or numpy arrays into DataFrame. Assume columns are as per input
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data, columns=self.input)
        # Take only trained input columns
        return self.model.predict(self.scaler.transform(data))

    def save(self, path):
        '''
        Serializes the model and associated parameters
        '''
        joblib.dump(self, path, compress=9)
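
# Usage sketch (hypothetical data; assumes scikit-learn is installed; column names
# and values are illustrative, not part of this module):
#
#   import pandas as pd
#   from gramex.ml import Classifier
#
#   data = pd.DataFrame({
#       'weight': [1, 2, 3, 4, 5, 6],
#       'length': [6, 5, 4, 3, 2, 1],
#       'label':  ['a', 'a', 'a', 'b', 'b', 'b'],
#   })
#   clf = Classifier()                  # defaults to sklearn.naive_bayes.BernoulliNB
#   clf.train(data)                     # last column ('label') is the output by default
#   clf.predict([[2, 5], [5, 2]])       # -> array of predicted labels
#   clf.save('model.pkl')               # serialize; reload with gramex.ml.load('model.pkl')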


def _conda_r_home():
    '''
    Returns the R home directory for Conda R if it is installed. Else None.

    Typically, people install Conda AND R (in any order), and use the system R
    (rather than the conda R) by placing it before Conda in the PATH.

    But the system R does not work with Conda rpy2. So we check if Conda R
    exists and return its path, so that it can be used as R_HOME.
    '''
    try:
        from conda.base.context import context
    except ImportError:
        app_log.error('Anaconda not installed. Cannot use Anaconda R')
        return None
    r_home = os.path.normpath(os.path.join(context.root_prefix, 'lib', 'R'))
    if os.path.isdir(os.path.join(r_home, 'bin')):
        return r_home
    app_log.error('Anaconda R not installed')
    return None


def r(code=None, path=None, rel=True, conda=True, convert=True,
      repo='https://cran.microsoft.com/', **kwargs):
    '''
    Runs the R script and returns the result.

    :arg str code: R code to execute.
    :arg str path: R script path. Cannot be used if code is specified
    :arg bool rel: True treats path as relative to the caller function's file
    :arg bool conda: True overrides R_HOME to use the Conda R
    :arg bool convert: True converts R objects to Pandas and vice versa
    :arg str repo: CRAN repo URL

    All other keyword arguments are passed as variables to the R global environment.
    '''
    # Use Conda R if possible
    if conda:
        r_home = _conda_r_home()
        if r_home:
            os.environ['R_HOME'] = r_home

    # Import the global R session
    try:
        from rpy2.robjects import r, pandas2ri, globalenv
    except ImportError:
        app_log.error('rpy2 not installed. Run "conda install rpy2"')
        raise
    except RuntimeError:
        app_log.error('Cannot find R. Set R_HOME env variable')
        raise

    # Set a repo so that install.packages() need not ask for one
    r('local({r <- getOption("repos"); r["CRAN"] <- "%s"; options(repos = r)})' % repo)

    # Activate or de-activate automatic conversion
    # https://pandas.pydata.org/pandas-docs/version/0.22.0/r_interface.html
    if convert:
        pandas2ri.activate()
    else:
        pandas2ri.deactivate()

    # Pass all other kwargs as global environment variables
    for key, val in kwargs.items():
        globalenv[key] = val

    if code and path:
        raise RuntimeError('Use r(code=) or r(path=...), not both')
    if path:
        # If rel=True, load the path relative to the caller's directory
        if rel:
            stack = inspect.getouterframes(inspect.currentframe(), 2)
            folder = os.path.dirname(os.path.abspath(stack[1][1]))
            path = os.path.join(folder, path)
        result = r.source(path, chdir=True)
        # source() returns a withVisible list: $value and $visible. Use only the first
        result = result[0]
    else:
        result = r(code)

    return result
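
# Usage sketch (assumes R and rpy2 are installed; the script name and the variable
# `n` below are illustrative). Keyword arguments become R globals:
#
#   from gramex.ml import r
#   r('sum(1:10)')              # returns a length-1 vector containing 55
#   r('rnorm(n)', n=5)          # `n` is injected into the R global environment
#   r(path='analyse.R')         # run a script located relative to the calling file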


def groupmeans(data, groups, numbers, cutoff=.01, quantile=.95, minsize=None,
               weight=None):
    '''
    Returns the significant differences in average between every pair of
    groups and numbers.

    :arg DataFrame data: pandas.DataFrame to analyze
    :arg list groups: category column names to group data by
    :arg list numbers: numeric column names to summarize data by
    :arg float cutoff: ignore anything with prob > cutoff.
        cutoff=None ignores significance checks, speeding it up a LOT.
    :arg float quantile: number that represents target improvement. Defaults to .95.
        The ``diff`` returned is the % impact of everyone moving to the 95th
        percentile
    :arg int minsize: each group should contain at least minsize values.
        If minsize=None, automatically set the minimum size to
        1% of the dataset, or 10, whichever is larger.
    :arg str weight: weight column name. If specified, averages are weighted by this column.
    '''
    from scipy.stats.mstats import ttest_ind
    if minsize is None:
        minsize = max(len(data.index) // 100, 10)

    if weight is None:
        means = data[numbers].mean()
    else:
        means = weighted_avg(data, numbers, weight)
    results = []
    for group in groups:
        grouped = data.groupby(group, sort=False)
        if weight is None:
            ave = grouped[numbers].mean()
        else:
            ave = grouped.apply(lambda v: weighted_avg(v, numbers, weight))
        ave['#'] = sizes = grouped.size()
        # Each group should contain at least minsize values
        biggies = sizes[sizes >= minsize].index
        # ... and at least 2 groups overall, to compare.
        if len(biggies) < 2:
            continue
        for number in numbers:
            if number == group:
                continue
            sorted_cats = ave[number][biggies].dropna().sort_values()
            if len(sorted_cats) < 2:
                continue
            lo = data[number][grouped.groups[sorted_cats.index[0]]].values
            hi = data[number][grouped.groups[sorted_cats.index[-1]]].values
            _, prob = ttest_ind(
                pd.np.ma.masked_array(lo, pd.np.isnan(lo)),
                pd.np.ma.masked_array(hi, pd.np.isnan(hi))
            )
            if prob > cutoff:
                continue
            results.append({
                'group': group,
                'number': number,
                'prob': prob,
                'gain': sorted_cats.iloc[-1] / means[number] - 1,
                'biggies': ave.loc[biggies][number].to_dict(),
                'means': ave[[number, '#']].sort_values(number).to_dict(),
            })

    results = pd.DataFrame(results)
    if len(results) > 0:
        results = results.set_index(['group', 'number'])
    return results.reset_index()  # Flatten multi-index.
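
# Usage sketch (hypothetical columns): find categorical columns whose groups show a
# statistically significant difference in a numeric column's mean.
#
#   import pandas as pd
#   from gramex.ml import groupmeans
#
#   sales = pd.DataFrame({
#       'city': ['A'] * 30 + ['B'] * 30,
#       'sales': list(range(30)) + list(range(100, 130)),
#   })
#   result = groupmeans(sales, groups=['city'], numbers=['sales'])
#   # One row per significant (group, number) pair, with columns such as
#   # 'prob' (t-test p-value) and 'gain' (top group's mean vs the overall mean)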


def weighted_avg(data, numeric_cols, weight):
    '''
    Computes the weighted average for the specified columns
    '''
    sumprod = data[numeric_cols].multiply(data[weight], axis=0).sum()
    return sumprod / data[weight].sum()
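
# Usage sketch (hypothetical data): a weighted mean of 'price', weighted by 'qty'.
#
#   import pandas as pd
#   from gramex.ml import weighted_avg
#
#   df = pd.DataFrame({'price': [10.0, 20.0], 'qty': [1, 3]})
#   weighted_avg(df, ['price'], 'qty')      # price -> 17.5, i.e. (10*1 + 20*3) / 4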


def _google_translate(q, source, target, key):
    import requests
    params = {'q': q, 'target': target, 'key': key}
    if source:
        params['source'] = source
    try:
        r = requests.post('https://translation.googleapis.com/language/translate/v2', data=params)
    except requests.RequestException:
        return app_log.exception('Cannot connect to Google Translate')
    response = r.json()
    if 'error' in response:
        return app_log.error('Google Translate API error: %s', response['error'])
    return {
        'q': q,
        't': [t['translatedText'] for t in response['data']['translations']],
        'source': [t['detectedSourceLanguage'] for t in response['data']['translations']],
        'target': [target] * len(q),
    }


translate_api = {
    'google': _google_translate
}
# Prevent translate cache from being accessed concurrently across threads.
# TODO: avoid threads and use Tornado ioloop/gen instead.
_translate_cache_lock = threading.Lock()


def translate(*q, **kwargs):
    '''
    Translate strings using the Google Translate API. Example::

        translate('Hello', 'World', source='en', target='de', key='...')

    returns a DataFrame::

        source  target  q      t
        en      de      Hello  ...
        en      de      World  ...

    The results can be cached via a ``cache={...}`` that has parameters for
    :py:func:`gramex.data.filter`. Example::

        translate('Hello', key='...', cache={'url': 'translate.xlsx'})

    :arg str q: one or more strings to translate
    :arg str source: 2-letter source language (e.g. en, fr, es, hi, cn, etc).
        If empty or None, auto-detects source
    :arg str target: 2-letter target language (e.g. en, fr, es, hi, cn, etc).
    :arg str key: Google Translate API key
    :arg dict cache: kwargs for :py:func:`gramex.data.filter`. Has keys such as
        url (required), table (for databases), sheet_name (for Excel), etc.

    Reference: https://cloud.google.com/translate/docs/apis
    '''
    import gramex.data
    source = kwargs.pop('source', None)
    target = kwargs.pop('target', None)
    key = kwargs.pop('key', None)
    cache = kwargs.pop('cache', None)
    api = kwargs.pop('api', 'google')
    if cache is not None:
        if not isinstance(cache, dict):
            raise ValueError('cache= must be a FormHandler dict config, not %r' % cache)

    # Store data in cache with fixed columns: source, target, q, t
    result = pd.DataFrame(columns=['source', 'target', 'q', 't'])
    if not q:
        return result
    original_q = q

    # Fetch from cache, if any
    if cache:
        try:
            args = {'q': q, 'target': [target] * len(q)}
            if source:
                args['source'] = [source] * len(q)
            with _translate_cache_lock:
                result = gramex.data.filter(args=args, **cache)
        except Exception:
            app_log.exception('Cannot query %r in translate cache: %r', args, dict(cache))
        # Remove already cached results from q
        q = [v for v in q if v not in set(result.get('q', []))]

    if len(q):
        new_data = translate_api[api](q, source, target, key)
        if new_data is not None:
            result = result.append(pd.DataFrame(new_data), sort=False)
            if cache:
                with _translate_cache_lock:
                    gramex.data.insert(id=['source', 'target', 'q'], args=new_data, **cache)

    # Sort results by q
    result['order'] = result['q'].map(original_q.index)
    result.sort_values('order', inplace=True)
    result.drop_duplicates(subset=['q'], inplace=True)
    del result['order']

    return result
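
# Usage sketch (the API key and cache file below are placeholders, not real values):
#
#   from gramex.ml import translate
#
#   df = translate('Hello', 'World', target='de', key='YOUR-API-KEY',
#                  cache={'url': 'translate-cache.xlsx'})
#   # df has columns source, target, q, t -- one row per input string. Repeat calls
#   # serve already-translated rows from translate-cache.xlsx instead of the API.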


@coroutine
def translater(handler, source='en', target='nl', key=None, cache=None, api='google'):
    args = handler.argparse(
        q={'nargs': '*', 'default': []},
        source={'default': source},
        target={'default': target}
    )
    import gramex
    result = yield gramex.service.threadpool.submit(
        translate, *args.q, source=args.source, target=args.target, key=key, cache=cache, api=api)

    # TODO: support gramex.data.download features
    handler.set_header('Content-Type', 'application/json; encoding="UTF-8"')
    raise Return(result.to_json(orient='records'))
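
# Sketch of the HTTP behaviour this handler provides, assuming the app's gramex.yaml
# maps a URL (say, /translate) to this function via FunctionHandler. The URL and the
# translated values below are illustrative:
#
#   GET /translate?q=Hello&q=World&target=de
#   -> [{"source": "en", "target": "de", "q": "Hello", "t": "Hallo"},
#       {"source": "en", "target": "de", "q": "World", "t": "Welt"}]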


_languagetool = {
    'defaults': {k: v for k, v in variables.items() if k.startswith('LT_')},
    'installed': os.path.isdir(variables['LT_CWD'])
}


@coroutine
def languagetool(handler, *args, **kwargs):
    import gramex
    merge(kwargs, _languagetool['defaults'], mode='setdefault')
    yield gramex.service.threadpool.submit(languagetool_download)
    if not handler:
        lang = kwargs.get('lang', 'en-us')
        q = kwargs.get('q', '')
    else:
        lang = handler.get_argument('lang', 'en-us')
        q = handler.get_argument('q', '')
    result = yield languagetoolrequest(q, lang, **kwargs)
    errors = json.loads(result.decode('utf8'))['matches']
    if errors:
        result = {
            "errors": errors,
        }
        corrected = list(q)
        d_offset = 0  # difference in the offset caused by the correction
        for error in errors:
            # only accept the first replacement for an error
            correction = error['replacements'][0]['value']
            offset, limit = error['offset'], error['length']
            offset += d_offset
            del corrected[offset:(offset + limit)]
            for i, char in enumerate(correction):
                corrected.insert(offset + i, char)
            d_offset += len(correction) - limit
        result['correction'] = "".join(corrected)
        result = json.dumps(result)
    raise Return(result)
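
# A standalone sketch of the correction-splicing loop above. The error dict is
# illustrative; real ones come from the LanguageTool response:
#
#   q = 'This are a test'
#   errors = [{'offset': 5, 'length': 3, 'replacements': [{'value': 'is'}]}]
#   corrected, d_offset = list(q), 0
#   for error in errors:
#       correction = error['replacements'][0]['value']
#       offset = error['offset'] + d_offset
#       corrected[offset:offset + error['length']] = list(correction)
#       d_offset += len(correction) - error['length']
#   ''.join(corrected)      # -> 'This is a test'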


@coroutine
def languagetoolrequest(text, lang='en-us', **kwargs):
    """Check grammar by making a request to the LanguageTool server.

    Parameters
    ----------
    text : str
        Text to check
    lang : str, optional
        Language. See a list of supported languages here: https://languagetool.org/api/v2/languages
    """
    client = AsyncHTTPClient()
    url = kwargs['LT_URL'].format(**kwargs)
    query = six.moves.urllib_parse.urlencode({'language': lang, 'text': text})
    url = url + query
    tries = 2   # See: https://github.com/gramener/gramex/pull/125#discussion_r266200480
    while tries:
        try:
            result = yield client.fetch(url)
            tries = 0
        except ConnectionRefusedError:
            # Start languagetool
            from gramex.cache import daemon
            cmd = [p.format(**kwargs) for p in kwargs['LT_CMD']]
            app_log.info('Starting: %s', ' '.join(cmd))
            if 'proc' not in _languagetool:
                import re
                _languagetool['proc'] = daemon(
                    cmd, cwd=kwargs['LT_CWD'],
                    first_line=re.compile(r"Server started\s*$"),
                    stream=True, timeout=5,
                    buffer_size=512
                )
            try:
                result = yield client.fetch(url)
                tries = 0
            except ConnectionRefusedError:
                yield sleep(1)
                tries -= 1
    raise Return(result.body)


def languagetool_download():
    if _languagetool['installed']:
        return
    import requests, zipfile, io    # noqa
    target = _languagetool['defaults']['LT_TARGET']
    if not os.path.isdir(target):
        os.makedirs(target)
    src = _languagetool['defaults']['LT_SRC'].format(**_languagetool['defaults'])
    app_log.info('Downloading languagetools from %s', src)
    stream = io.BytesIO(requests.get(src).content)
    app_log.info('Unzipping languagetools to %s', target)
    zipfile.ZipFile(stream).extractall(target)
    _languagetool['installed'] = True


# Gramex 1.48 spelt translater as translator. Accept both spellings.
translator = translater