import os
import six
import json
import inspect
import threading
import joblib
import pandas as pd
from tornado.gen import coroutine, Return, sleep
from tornado.httpclient import AsyncHTTPClient
from gramex.config import locate, app_log, merge, variables

# Expose joblib.load via gramex.ml
load = joblib.load  # noqa


class Classifier(object):
    '''
    :arg data DataFrame: data to train / re-train the model with
    :arg model_class str: model class to use (default: ``sklearn.naive_bayes.BernoulliNB``)
    :arg model_kwargs dict: kwargs to pass to the model class constructor (default: ``{}``)
    :arg output str: output column name (default: last column in training data)
    :arg input list: input column names (default: all columns except ``output``)
    :arg labels list: list of possible output values (default: unique ``output`` in training)
    '''

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
        self.model_class = kwargs.get('model_class', 'sklearn.naive_bayes.BernoulliNB')
        self.trained = False  # Flag: has the model been trained yet?

    def __str__(self):
        return repr(vars(self))

    def update_params(self, params):
        model_keys = ('model_class', 'url', 'input', 'output', 'trained', 'query', 'model_kwargs')
        # Handler arguments arrive as lists. Unwrap single values, except for
        # 'input', which is always a list of column names
        model_params = {k: v[0] if isinstance(v, list) and k != 'input' else v
                        for k, v in params.items() if k in model_keys}
        if model_params:
            self.trained = params.get('trained', False)
            vars(self).update(model_params)

    def train(self, data):
        '''
        :arg data DataFrame: data to train / re-train the model with

        The remaining parameters are read from the instance (set via
        ``__init__`` or ``update_params``) rather than passed to this method:

        :arg model_class str: model class to use (default: ``sklearn.naive_bayes.BernoulliNB``)
        :arg model_kwargs dict: kwargs to pass to the model class constructor (default: ``{}``)
        :arg output str: output column name (default: last column in training data)
        :arg input list: input column names (default: all columns except ``output``)
        :arg labels list: list of possible output values (default: unique ``output`` in training)

        Notes:

        - If the model has already been trained, extend the model. Else create it
        '''
        self.output = vars(self).get('output', data.columns[-1])
        self.input = vars(self).get('input', list(data.columns[:-1]))
        self.model_kwargs = vars(self).get('model_kwargs', {})
        self.labels = vars(self).get('labels', None)
        # If model_kwargs have changed since we trained last, re-train the model
        if not self.trained and hasattr(self, 'model'):
            vars(self).pop('model')
        if not hasattr(self, 'model'):
            # Split the data into input (x) and output (y)
            x, y = data[self.input], data[self.output]
            # Transform the data
            from sklearn.preprocessing import StandardScaler
            self.scaler = StandardScaler()
            self.scaler.fit(x)
            # Train the classifier. Partially, if possible
            try:
                clf = locate(self.model_class)(**self.model_kwargs)
            except TypeError:
                raise ValueError('{0} is not a valid model class'.format(self.model_class))
            if self.labels and hasattr(clf, 'partial_fit'):
                try:
                    clf.partial_fit(self.scaler.transform(x), y, classes=self.labels)
                except AttributeError:
                    raise ValueError('{0} does not support partial fit'.format(self.model_class))
            else:
                clf.fit(self.scaler.transform(x), y)
            self.model = clf
        # Extend the model. Note: partial_fit() cannot add classes after the
        # first fit, so y should not contain labels the model has not seen
        else:
            x, y = data[self.input], data[self.output]
            self.model.partial_fit(self.scaler.transform(x), y)

    def predict(self, data):
        '''
        Return a Series that has the results of the classification of data
        '''
        # Convert lists of lists or numpy arrays into a DataFrame. Assume columns are as per input
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data, columns=self.input)
        # Take only the trained input columns, in training order
        return self.model.predict(self.scaler.transform(data[self.input]))

    def save(self, path):
        '''
        Serializes the model and associated parameters to ``path``
        '''
        joblib.dump(self, path, compress=9)
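
# A minimal usage sketch for Classifier (illustrative; not called anywhere in
# this module). The DataFrame, its column names ('a', 'b', 'y') and the
# 'model.pkl' path are assumptions made up for this example.
def _example_classifier():
    data = pd.DataFrame({'a': [0, 1, 0, 1], 'b': [1, 1, 0, 0], 'y': [0, 1, 0, 1]})
    clf = Classifier()                      # defaults to sklearn.naive_bayes.BernoulliNB
    clf.train(data)                         # output defaults to the last column, 'y'
    result = clf.predict([[0, 1], [1, 0]])  # lists are mapped to the input columns
    clf.save('model.pkl')                   # load('model.pkl') restores the instance
    return result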

def _conda_r_home():
    '''
    Returns the R home directory for Conda R if it is installed. Else None.

    Typically, people install Conda AND R (in any order), and use the system R
    (rather than the Conda R) by placing it before Conda in the PATH.

    But the system R does not work with Conda rpy2. So we check if Conda R
    exists and return its path, so that it can be used as R_HOME.
    '''
    try:
        from conda.base.context import context
    except ImportError:
        app_log.error('Anaconda not installed. Cannot use Anaconda R')
        return None
    r_home = os.path.normpath(os.path.join(context.root_prefix, 'lib', 'R'))
    if os.path.isdir(os.path.join(r_home, 'bin')):
        return r_home
    app_log.error('Anaconda R not installed')
    return None

def r(code=None, path=None, rel=True, conda=True, convert=True,
      repo='https://cran.microsoft.com/', **kwargs):
    '''
    Runs the R script and returns the result.

    :arg str code: R code to execute.
    :arg str path: R script path. Cannot be used if code is specified
    :arg bool rel: True treats path as relative to the caller function's file
    :arg bool conda: True overrides R_HOME to use the Conda R
    :arg bool convert: True converts R objects to Pandas and vice versa
    :arg str repo: CRAN repo URL

    All other keyword arguments are passed on as global variables to the R session
    '''
    # Use Conda R if possible
    if conda:
        r_home = _conda_r_home()
        if r_home:
            os.environ['R_HOME'] = r_home

    # Import the global R session
    try:
        from rpy2.robjects import r, pandas2ri, globalenv
    except ImportError:
        app_log.error('rpy2 not installed. Run "conda install rpy2"')
        raise
    except RuntimeError:
        app_log.error('Cannot find R. Set R_HOME env variable')
        raise

    # Set a repo so that install.packages() need not ask for one
    r('local({r <- getOption("repos"); r["CRAN"] <- "%s"; options(repos = r)})' % repo)

    # Activate or de-activate automatic conversion
    # https://pandas.pydata.org/pandas-docs/version/0.22.0/r_interface.html
    if convert:
        pandas2ri.activate()
    else:
        pandas2ri.deactivate()

    # Pass all other kwargs as global environment variables
    for key, val in kwargs.items():
        globalenv[key] = val

    if code and path:
        raise RuntimeError('Use r(code=...) or r(path=...), not both')
    if path:
        # If rel=True, treat path as relative to the caller's directory
        if rel:
            stack = inspect.getouterframes(inspect.currentframe(), 2)
            folder = os.path.dirname(os.path.abspath(stack[1][1]))
            path = os.path.join(folder, path)
        result = r.source(path, chdir=True)
        # source() returns a withVisible list: $value and $visible. Use only the first
        result = result[0]
    else:
        result = r(code)

    return result
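
# A minimal usage sketch for r() (illustrative; it needs R and rpy2 installed,
# so it is not called anywhere in this module). The R code is an arbitrary example.
def _example_r():
    # r(path='analyze.R') would instead run a (hypothetical) script relative to this file
    return r(code='sum(1:10)')  # returns 55 as a length-1 R vector / array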

def groupmeans(data, groups, numbers, cutoff=.01, quantile=.95, minsize=None,
               weight=None):
    '''
    Yields the significant differences in average between every pair of
    groups and numbers.

    :arg DataFrame data: pandas.DataFrame to analyze
    :arg list groups: category column names to group data by
    :arg list numbers: numeric column names to summarize data by
    :arg float cutoff: ignore anything with prob > cutoff.
        cutoff=None ignores significance checks, speeding it up a LOT.
    :arg float quantile: number that represents target improvement. Defaults to .95.
        The ``gain`` returned is the % impact of everyone moving to the 95th
        percentile
    :arg int minsize: each group should contain at least minsize values.
        If minsize=None, automatically set the minimum size to
        1% of the dataset, or 10, whichever is larger.
    :arg str weight: weight column name. If specified, weighted averages are
        used instead of simple means
    '''
    from scipy.stats.mstats import ttest_ind
    if minsize is None:
        minsize = max(len(data.index) // 100, 10)

    if weight is None:
        means = data[numbers].mean()
    else:
        means = weighted_avg(data, numbers, weight)
    results = []
    for group in groups:
        grouped = data.groupby(group, sort=False)
        if weight is None:
            ave = grouped[numbers].mean()
        else:
            ave = grouped.apply(lambda v: weighted_avg(v, numbers, weight))
        ave['#'] = sizes = grouped.size()
        # Each group should contain at least minsize values
        biggies = sizes[sizes >= minsize].index
        # ... and at least 2 groups overall, to compare.
        if len(biggies) < 2:
            continue
        for number in numbers:
            if number == group:
                continue
            sorted_cats = ave[number][biggies].dropna().sort_values()
            if len(sorted_cats) < 2:
                continue
            lo = data[number][grouped.groups[sorted_cats.index[0]]].values
            hi = data[number][grouped.groups[sorted_cats.index[-1]]].values
            _, prob = ttest_ind(
                pd.np.ma.masked_array(lo, pd.np.isnan(lo)),
                pd.np.ma.masked_array(hi, pd.np.isnan(hi))
            )
            if prob > cutoff:
                continue
            results.append({
                'group': group,
                'number': number,
                'prob': prob,
                'gain': sorted_cats.iloc[-1] / means[number] - 1,
                'biggies': ave.loc[biggies][number].to_dict(),
                'means': ave[[number, '#']].sort_values(number).to_dict(),
            })

    results = pd.DataFrame(results)
    if len(results) > 0:
        results = results.set_index(['group', 'number'])
    return results.reset_index()  # Flatten the multi-index

def weighted_avg(data, numeric_cols, weight):
    '''
    Computes the weighted average for the specified columns
    '''
    sumprod = data[numeric_cols].multiply(data[weight], axis=0).sum()
    return sumprod / data[weight].sum()
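
# A minimal usage sketch for groupmeans() and weighted_avg() (illustrative; not
# called anywhere in this module). The DataFrame and its column names ('city',
# 'sales', 'units') are assumptions made up for this example.
def _example_groupmeans():
    import numpy as np
    data = pd.DataFrame({
        'city': np.random.choice(['NY', 'SF', 'LA'], size=1000),
        'sales': np.random.rand(1000),
        'units': np.random.randint(1, 10, size=1000),
    })
    # Mean sales weighted by units sold
    avg = weighted_avg(data, ['sales'], 'units')
    # Significant differences in mean sales across cities (likely empty for random data)
    diffs = groupmeans(data, groups=['city'], numbers=['sales'])
    return avg, diffs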

def _google_translate(q, source, target, key):
    import requests
    params = {'q': q, 'target': target, 'key': key}
    if source:
        params['source'] = source
    try:
        r = requests.post('https://translation.googleapis.com/language/translate/v2', data=params)
    except requests.RequestException:
        return app_log.exception('Cannot connect to Google Translate')
    response = r.json()
    if 'error' in response:
        return app_log.error('Google Translate API error: %s', response['error'])
    return {
        'q': q,
        't': [t['translatedText'] for t in response['data']['translations']],
        # detectedSourceLanguage is only present when the source was auto-detected
        'source': [t.get('detectedSourceLanguage', source)
                   for t in response['data']['translations']],
        'target': [target] * len(q),
    }

translate_api = {
    'google': _google_translate
}
# Prevent translate cache from being accessed concurrently across threads.
# TODO: avoid threads and use Tornado ioloop/gen instead.
_translate_cache_lock = threading.Lock()
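
# A minimal sketch of plugging a custom backend into translate_api
# (illustrative; nothing registers it by default). Any callable with the same
# signature as _google_translate that returns the same dict shape will work.
def _example_identity_api(q, source, target, key):
    # A do-nothing "translation" that echoes the input; handy for testing
    return {'q': q, 't': list(q), 'source': [source] * len(q), 'target': [target] * len(q)}

# Registering it would enable translate(..., api='identity'):
# translate_api['identity'] = _example_identity_api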

def translate(*q, **kwargs):
    '''
    Translate strings using the Google Translate API. Example::

        translate('Hello', 'World', source='en', target='de', key='...')

    returns a DataFrame::

        source  target  q      t
        en      de      Hello  ...
        en      de      World  ...

    The results can be cached via a ``cache={...}`` that has parameters for
    :py:func:`gramex.data.filter`. Example::

        translate('Hello', key='...', cache={'url': 'translate.xlsx'})

    :arg str q: one or more strings to translate
    :arg str source: 2-letter source language (e.g. en, fr, es, hi, cn, etc).
        If empty or None, the source language is auto-detected
    :arg str target: 2-letter target language (e.g. en, fr, es, hi, cn, etc).
    :arg str key: Google Translate API key
    :arg dict cache: kwargs for :py:func:`gramex.data.filter`. Has keys such as
        url (required), table (for databases), sheet_name (for Excel), etc.

    Reference: https://cloud.google.com/translate/docs/apis
    '''
    import gramex.data
    source = kwargs.pop('source', None)
    target = kwargs.pop('target', None)
    key = kwargs.pop('key', None)
    cache = kwargs.pop('cache', None)
    api = kwargs.pop('api', 'google')
    if cache is not None:
        if not isinstance(cache, dict):
            raise ValueError('cache= must be a FormHandler dict config, not %r' % cache)

    # Store data in cache with fixed columns: source, target, q, t
    result = pd.DataFrame(columns=['source', 'target', 'q', 't'])
    if not q:
        return result
    original_q = q

    # Fetch from cache, if any
    if cache:
        try:
            args = {'q': q, 'target': [target] * len(q)}
            if source:
                args['source'] = [source] * len(q)
            with _translate_cache_lock:
                result = gramex.data.filter(args=args, **cache)
        except Exception:
            app_log.exception('Cannot query %r in translate cache: %r', args, dict(cache))
        # Remove already cached results from q
        q = [v for v in q if v not in set(result.get('q', []))]

    if len(q):
        new_data = translate_api[api](q, source, target, key)
        if new_data is not None:
            result = result.append(pd.DataFrame(new_data), sort=False)
            if cache:
                with _translate_cache_lock:
                    gramex.data.insert(id=['source', 'target', 'q'], args=new_data, **cache)

    # Sort results in the original q order and drop duplicates
    result['order'] = result['q'].map(original_q.index)
    result.sort_values('order', inplace=True)
    result.drop_duplicates(subset=['q'], inplace=True)
    del result['order']

    return result
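
# A minimal usage sketch for translate() (illustrative; it needs a valid Google
# Translate API key, so it is not called anywhere in this module). The key and
# the cache file name are placeholders.
def _example_translate():
    return translate('Hello', 'World', source='en', target='de',
                     key='YOUR-API-KEY', cache={'url': 'translate-cache.xlsx'})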

@coroutine
def translater(handler, source='en', target='nl', key=None, cache=None, api='google'):
    args = handler.argparse(
        q={'nargs': '*', 'default': []},
        source={'default': source},
        target={'default': target}
    )
    import gramex
    result = yield gramex.service.threadpool.submit(
        translate, *args.q, source=args.source, target=args.target, key=key, cache=cache, api=api)

    # TODO: support gramex.data.download features
    handler.set_header('Content-Type', 'application/json; encoding="UTF-8"')
    raise Return(result.to_json(orient='records'))

_languagetool = {
    'defaults': {k: v for k, v in variables.items() if k.startswith('LT_')},
    'installed': os.path.isdir(variables['LT_CWD'])
}

@coroutine
def languagetool(handler, *args, **kwargs):
    import gramex
    merge(kwargs, _languagetool['defaults'], mode='setdefault')
    yield gramex.service.threadpool.submit(languagetool_download)
    if not handler:
        lang = kwargs.get('lang', 'en-us')
        q = kwargs.get('q', '')
    else:
        lang = handler.get_argument('lang', 'en-us')
        q = handler.get_argument('q', '')
    result = yield languagetoolrequest(q, lang, **kwargs)
    errors = json.loads(result.decode('utf8'))['matches']
    if errors:
        result = {
            'errors': errors,
        }
        # Apply each suggested correction to q. Earlier corrections change the
        # string length, so track the shift they cause in later offsets
        corrected = list(q)
        d_offset = 0  # difference in the offset caused by the corrections so far
        for error in errors:
            # Only accept the first replacement for an error
            correction = error['replacements'][0]['value']
            offset, limit = error['offset'], error['length']
            offset += d_offset
            del corrected[offset:(offset + limit)]
            for i, char in enumerate(correction):
                corrected.insert(offset + i, char)
            d_offset += len(correction) - limit
        result['correction'] = ''.join(corrected)
        result = json.dumps(result)
    raise Return(result)
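
# A minimal sketch of the offset arithmetic above (illustrative; the matches
# are made up, mimicking LanguageTool's 'matches' shape). Correcting
# 'helo wrld' to 'hello world' shifts the second match's offset by +1.
def _example_corrections():
    q = 'helo wrld'
    matches = [
        {'offset': 0, 'length': 4, 'replacements': [{'value': 'hello'}]},
        {'offset': 5, 'length': 4, 'replacements': [{'value': 'world'}]},
    ]
    corrected, d_offset = list(q), 0
    for error in matches:
        correction = error['replacements'][0]['value']
        offset = error['offset'] + d_offset
        corrected[offset:offset + error['length']] = list(correction)
        d_offset += len(correction) - error['length']
    return ''.join(corrected)  # -> 'hello world'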

@coroutine
def languagetoolrequest(text, lang='en-us', **kwargs):
    """Check grammar by making a request to the LanguageTool server.

    Parameters
    ----------
    text : str
        Text to check
    lang : str, optional
        Language. See the list of supported languages at https://languagetool.org/api/v2/languages
    """
    client = AsyncHTTPClient()
    url = kwargs['LT_URL'].format(**kwargs)
    query = six.moves.urllib_parse.urlencode({'language': lang, 'text': text})
    url = url + query
    tries = 2  # See: https://github.com/gramener/gramex/pull/125#discussion_r266200480
    while tries:
        try:
            result = yield client.fetch(url)
            tries = 0
        except ConnectionRefusedError:
            # The LanguageTool server is not running. Start it as a daemon
            from gramex.cache import daemon
            cmd = [p.format(**kwargs) for p in kwargs['LT_CMD']]
            app_log.info('Starting: %s', ' '.join(cmd))
            if 'proc' not in _languagetool:
                import re
                _languagetool['proc'] = daemon(
                    cmd, cwd=kwargs['LT_CWD'],
                    first_line=re.compile(r'Server started\s*$'),
                    stream=True, timeout=5,
                    buffer_size=512
                )
            try:
                result = yield client.fetch(url)
                tries = 0
            except ConnectionRefusedError:
                # Give the daemon a moment to start up, then retry
                yield sleep(1)
                tries -= 1
    raise Return(result.body)

def languagetool_download():
    if _languagetool['installed']:
        return
    import requests, zipfile, io  # noqa
    target = _languagetool['defaults']['LT_TARGET']
    if not os.path.isdir(target):
        os.makedirs(target)
    src = _languagetool['defaults']['LT_SRC'].format(**_languagetool['defaults'])
    app_log.info('Downloading languagetools from %s', src)
    stream = io.BytesIO(requests.get(src).content)
    app_log.info('Unzipping languagetools to %s', target)
    zipfile.ZipFile(stream).extractall(target)
    _languagetool['installed'] = True

# Gramex 1.48 spelt translater as translator. Accept both spellings.
translator = translater