Coverage for gramex\handlers\modelhandler.py : 78%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import json
3import gramex.ml
4import pandas as pd
5import gramex.cache
6import gramex.data
7from gramex.handlers import BaseHandler
8import tornado.escape
class ModelHandler(BaseHandler):
    '''
    Allows users to create API endpoints to train/test models exposed through Scikit-Learn.

    Each model is pickled to ``<path>/<name>.pkl``, where ``<path>`` is the
    ``path`` passed to :meth:`setup` and ``<name>`` is the first path argument
    of the request URL. ``/<name>/data`` addresses the model's training data
    (``model.url``) instead of the model itself.

    TODO: support Scikit-Learn Pipelines for data transformations.
    '''
    @classmethod
    def setup(cls, path, **kwargs):
        '''Store ``path`` (the directory holding model pickles) on the class.'''
        super(ModelHandler, cls).setup(**kwargs)
        cls.path = path

    def prepare(self):
        '''
        Gets called automatically at the beginning of every request.

        Builds ``self.pickle_file_path`` from the model name in the request path,
        and merges the JSON request body with the URL query args into
        ``self.request_body``. URL query args take precedence over the request
        body when a key exists in both. Multi-row parameters are expected to be
        formatted like the output of ``handler.argparse``.
        '''
        self.set_header('Content-Type', 'application/json; charset=utf-8')
        # NOTE(review): path_args[0] comes from the request URL and the resulting
        # file is later loaded via gramex.ml.load (presumably an unpickler). If the
        # route pattern lets "/" or ".." through, this can reach files outside
        # self.path. Confirm the URL pattern restricts the model name.
        self.pickle_file_path = os.path.join(self.path, self.path_args[0] + '.pkl')
        self.request_body = {}
        if self.request.body:
            self.request_body = tornado.escape.json_decode(self.request.body)
        if self.args:
            # Query args overwrite body keys, giving them precedence.
            self.request_body.update(self.args)
        url = self.request_body.get('url', '')
        # For file-based data, keep only the file name and resolve it inside
        # self.path so a request cannot point at arbitrary files on disk.
        # NOTE(review): non-file URLs (e.g. database URLs) are accepted as sent by
        # the client and are not validated here.
        if url and gramex.data.get_engine(url) == 'file':
            self.request_body['url'] = os.path.join(self.path, os.path.split(url)[-1])

    def get_data_flag(self):
        '''Return True if the request is made to /model/name/data, else False.'''
        return len(self.path_args) > 1 and self.path_args[1] == 'data'

    def get(self, *path_args):
        '''
        Request sent to model/name with no args returns model information,
        (that can be changed via PUT/POST).
        Request to model/name with args will accept model input and produce predictions.
        Request to model/name/data will return the training data specified in model.url,
        in the format given by the ``_format`` arg (default: json).
        '''
        model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        if self.get_data_flag():
            # '_format' must be whitelisted here, otherwise listify drops it and the
            # pop below always falls back to json.
            file_kwargs = self.listify(
                ['engine', 'url', 'ext', 'table', 'query', 'id', '_format'])
            _format = file_kwargs.pop('_format', ['json'])[0]
            # TODO: Add Support for formhandler filters/limit/sorting/groupby
            # NOTE(review): the Content-Type header set in prepare() is JSON even
            # when a non-JSON _format is requested.
            data = gramex.data.filter(model.url, **file_kwargs)
            self.write(gramex.data.download(data, format=_format, **file_kwargs))
            return
        # If no model input columns are passed, return model info
        model_input = vars(model).get('input', '')
        if not model_input or not any(col in self.args for col in model_input):
            model_info = {key: value for key, value in vars(model).items()
                          if key not in ('model', 'scaler')}
            self.write(json.dumps(model_info, indent=4))
            return
        self._predict(model)

    def put(self, *path_args, **path_kwargs):
        '''
        Request to /model/name/ with no params will create a blank model.
        Request to /model/name/ with args will interpret them as model parameters.
        Set Model-Retrain: true in headers to either train a model from scratch or extend it.
        To extend a trained model, don't update the parameters and send Model-Retrain in headers.
        Request to /model/name/data with args (including id) will update the training data;
        this doesn't currently work on DataFrames because of a gramex.data bug.
        '''
        try:
            model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        except EnvironmentError:  # noqa
            # No pickle yet: create a new model from the request parameters.
            model = gramex.ml.Classifier(**self.request_body)
        if self.get_data_flag():
            file_kwargs = self.listify(model.input + [model.output] + ['id'])
            gramex.data.update(model.url, args=file_kwargs, id=file_kwargs['id'])
        elif not self._train(model):
            # _train saves the model itself when it retrains; otherwise save here
            # so that parameter updates / new models are persisted.
            model.save(self.pickle_file_path)

    def _predict(self, model):
        '''
        Helper that writes ``model.predict`` results for the request's inputs.

        The request values for ``model.input`` become a DataFrame, a ``result``
        column is added, and the rows are written as JSON records.
        Returns without writing anything if the request has none of
        ``model.input``. Raises AttributeError if inputs were sent but the model
        is not trained.
        '''
        params = self.listify(model.input)
        if not params:
            # Nothing to predict. Without this guard a trained model would fail
            # with a KeyError selecting model.input from an empty DataFrame.
            return
        # getattr: a model without a `trained` attribute counts as untrained.
        if not (hasattr(model, 'model') and getattr(model, 'trained', False)):
            raise AttributeError('model not trained')
        data = pd.DataFrame(params)
        data = data[model.input]
        data['result'] = model.predict(data)
        self.write(data.to_json(orient='records'))

    def post(self, *path_args, **path_kwargs):
        '''
        Request to /model/name/ with Model-Retrain: true in the headers will
        attempt to update model parameters and retrain/extend the model.
        Request to /model/name/ with model input as body/query args and no Model-Retrain
        in headers will return predictions.
        Request to /model/name/data lets people add rows to the training data.

        Raises EnvironmentError if the model does not exist yet (create it with PUT).
        '''
        # load model object - if it doesn't exist, ask the client to create it
        try:
            model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        except EnvironmentError:  # noqa
            message = 'Please Send PUT Request, model does not exist'
            self.write({'Error': message})
            # NOTE(review): Tornado's default error handling clears unflushed output
            # before sending the error page, so the body written above may be lost;
            # the message is also attached to the exception so it is not silent.
            raise EnvironmentError(message)  # noqa
        if self.get_data_flag():
            file_kwargs = self.listify(model.input + [model.output])
            gramex.data.insert(model.url, args=file_kwargs)
        elif not self._train(model):
            # No Model-Retrain header: treat the request as a predict request.
            self._predict(model)

    def delete(self, *path_args):
        '''
        Request to /model/name/ will delete the trained model's pickle file, if present.
        Request to /model/name/data needs id and will delete rows from the training data.

        Raises EnvironmentError for /data requests if the model does not exist.
        '''
        if self.get_data_flag():
            file_kwargs = self.listify(['id'])
            try:
                model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
            except EnvironmentError:  # noqa
                message = 'Please Send PUT Request, model does not exist'
                self.write({'Error': message})
                raise EnvironmentError(message)  # noqa
            gramex.data.delete(model.url, args=file_kwargs, id=file_kwargs['id'])
            return
        if os.path.exists(self.pickle_file_path):
            os.unlink(self.pickle_file_path)

    def _train(self, model):
        '''
        Update ``model`` parameters from the request body, then, if the request has
        a ``Model-Retrain`` header, retrain the model on its ``url`` data and pickle it.

        Returns True if the model was retrained and saved, False otherwise.
        Raises AttributeError if retraining is requested but the model has no url.
        '''
        # Update model parameters
        model.update_params(self.request_body)
        # NOTE(review): only the header's presence is checked, not its value.
        if 'Model-Retrain' not in self.request.headers:
            return False
        # Guard only the attribute lookup, so an AttributeError raised inside
        # gramex.data.filter or model.train is not misreported as a missing url.
        try:
            url = model.url
        except AttributeError:
            raise AttributeError('Model does not have a url')
        # Pass non model kwargs to gramex.data.filter
        data = gramex.data.filter(
            url,
            args=self.listify(['engine', 'url', 'ext', 'table', 'query', 'id']))
        model.train(data)
        model.trained = True
        model.save(self.pickle_file_path)
        return True

    def listify(self, checklst):
        '''
        Return the entries of ``self.request_body`` whose keys are in ``checklst``,
        with each value wrapped in a list unless it already is one. Some functions
        in gramex.data expect list values.
        '''
        return {
            key: value if isinstance(value, list) else [value]
            for key, value in self.request_body.items()
            if key in checklst
        }