Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import json 

3import gramex.ml 

4import pandas as pd 

5import gramex.cache 

6import gramex.data 

7from gramex.handlers import BaseHandler 

8import tornado.escape 

9 

10 

class ModelHandler(BaseHandler):
    '''
    Allows users to create API endpoints to train/test models exposed through Scikit-Learn.
    TODO: support Scikit-Learn Pipelines for data transformations.
    '''
    # Request keys that describe the training-data source (forwarded to gramex.data)
    # rather than model inputs.
    _DATA_KEYS = ('engine', 'url', 'ext', 'table', 'query', 'id')

    @classmethod
    def setup(cls, path, **kwargs):
        '''
        Configure the handler class.

        ``path`` is stored on the class and is used as the directory that model
        pickle files (and file-based data URLs) are joined onto.
        All other keyword arguments are forwarded to ``BaseHandler.setup``.
        '''
        super(ModelHandler, cls).setup(**kwargs)
        cls.path = path

    def prepare(self):
        '''
        Gets called automatically at the beginning of every request.
        Takes the model name from the request path and creates the pickle file path.
        Also merges the request body and the URL query args.
        URL query args have precedence over the request body in case both exist.
        Expects multi-row parameters to be formatted as the output of handler.argparse.
        '''
        self.set_header('Content-Type', 'application/json; charset=utf-8')
        self.pickle_file_path = os.path.join(
            self.path, self.path_args[0] + '.pkl')
        self.request_body = {}
        if self.request.body:
            self.request_body = tornado.escape.json_decode(self.request.body)
        if self.args:
            # update() lets query args overwrite same-named keys from the JSON body
            self.request_body.update(self.args)
        url = self.request_body.get('url', '')
        if url and gramex.data.get_engine(url) == 'file':
            # Keep only the final path component so file URLs resolve inside self.path
            self.request_body['url'] = os.path.join(self.path, os.path.split(url)[-1])

    def get_data_flag(self):
        '''
        Return True if the request is made to /model/name/data, else False.
        '''
        return len(self.path_args) > 1 and self.path_args[1] == 'data'

    def get(self, *path_args):
        '''
        Request sent to model/name with no args returns model information,
        (that can be changed via PUT/POST).
        Request to model/name with args will accept model input and produce predictions.
        Request to model/name/data will return the training data specified in model.url,
        this should accept most formhandler flags and filters as well.
        The download format is taken from ``_format`` (default: json).
        '''
        model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        if self.get_data_flag():
            # '_format' must be part of the checklist: listify() drops every key that
            # is not listed, so popping it afterwards would always yield the default.
            file_kwargs = self.listify(self._DATA_KEYS + ('_format',))
            _format = (file_kwargs.pop('_format', None) or ['json'])[0]
            # NOTE(review): prepare() always sets a JSON Content-Type header; confirm
            # whether non-JSON formats need a different header.
            # TODO: Add Support for formhandler filters/limit/sorting/groupby
            data = gramex.data.filter(model.url, **file_kwargs)
            self.write(gramex.data.download(data, format=_format, **file_kwargs))
            return
        # If no model columns are passed, return model info
        model_vars = vars(model)
        if not model_vars.get('input', '') or not any(col in self.args for col in model.input):
            # 'model' and 'scaler' are left out of the JSON model description
            model_info = {k: v for k, v in model_vars.items()
                          if k not in ('model', 'scaler')}
            self.write(json.dumps(model_info, indent=4))
            return
        self._predict(model)

    def put(self, *path_args, **path_kwargs):
        '''
        Request to /model/name/ with no params will create a blank model.
        Request to /model/name/ with args will interpret as model parameters.
        Set Model-Retrain: true in headers to either train a model from scratch or extend it.
        To Extend a trained model, don't update the parameters and send Model-Retrain in headers.
        Request to /model/name/data with args will update the training data,
        doesn't currently work on DF's thanks to the gramex.data bug.
        '''
        try:
            model = gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        except EnvironmentError:    # noqa
            # No readable pickle yet: build a new classifier from the request parameters
            model = gramex.ml.Classifier(**self.request_body)
        if self.get_data_flag():
            file_kwargs = self.listify(model.input + [model.output] + ['id'])
            # Raises KeyError if the request does not supply 'id'
            gramex.data.update(model.url, args=file_kwargs, id=file_kwargs['id'])
        elif not self._train(model):
            # _train() already saves after retraining; otherwise persist the
            # (possibly updated) parameters here.
            model.save(self.pickle_file_path)

    def _predict(self, model):
        '''
        Helper function for model.predict.
        Writes the model inputs plus a ``result`` column as JSON records.
        Raises AttributeError if inputs were sent but the model is not trained.
        '''
        params = self.listify(model.input)
        if hasattr(model, 'model') and model.trained:
            data = pd.DataFrame(params)
            # Select the columns in the order listed in model.input
            data = data[model.input]
            data['result'] = model.predict(data)
            self.write(data.to_json(orient='records'))
        elif params:
            raise AttributeError('model not trained')

    def post(self, *path_args, **path_kwargs):
        '''
        Request to /model/name/ with Model-Retrain: true in the headers will,
        attempt to update model parameters and retrain/extend the model.
        Request to /model/name/ with model input as body/query args and no Model-Retrain,
        in headers will return predictions.
        Request to /model/name/data lets people add rows to the training data.
        '''
        # load model object - if it doesn't exist, send a response asking to create the model
        model = self._load_existing_model()
        if self.get_data_flag():
            file_kwargs = self.listify(model.input + [model.output])
            gramex.data.insert(model.url, args=file_kwargs)
        elif not self._train(model):
            # If /data/ is not path_args[1] and no retrain was requested,
            # post is sending a predict request
            self._predict(model)

    def delete(self, *path_args):
        '''
        Request to /model/name/ will delete the trained model.
        Request to /model/name/data needs id and will delete rows from the training data.
        '''
        if self.get_data_flag():
            file_kwargs = self.listify(['id'])
            model = self._load_existing_model()
            # Raises KeyError if the request does not supply 'id'
            gramex.data.delete(model.url, args=file_kwargs, id=file_kwargs['id'])
            return
        if os.path.exists(self.pickle_file_path):
            os.unlink(self.pickle_file_path)

    def _load_existing_model(self):
        '''
        Load the pickled model for this request.
        If it cannot be opened, write an error payload asking the client to send
        a PUT request and re-raise the original EnvironmentError.
        '''
        try:
            return gramex.cache.open(self.pickle_file_path, gramex.ml.load)
        except EnvironmentError:    # noqa
            # Log error
            self.write({'Error': 'Please Send PUT Request, model does not exist'})
            # Bare raise keeps the original errno / filename / traceback
            raise

    def _train(self, model):
        '''
        Updates model parameters from the request, then looks for Model-Retrain
        in the request headers; if present, trains the model and pickles it.
        Only the presence of the header is checked, not its value.
        Returns True if the model was retrained and saved, else False.
        Raises AttributeError if retraining is requested but the model has no url.
        '''
        # Update model parameters
        model.update_params(self.request_body)
        if 'Model-Retrain' not in self.request.headers:
            return False
        # Only the attribute lookup is guarded, so an AttributeError raised inside
        # gramex.data.filter is not misreported as a missing url.
        try:
            url = model.url
        except AttributeError:
            raise AttributeError('Model does not have a url')
        # Pass non model kwargs to gramex.data.filter
        data = gramex.data.filter(url, args=self.listify(self._DATA_KEYS))
        # Train the model.
        model.train(data)
        model.trained = True
        model.save(self.pickle_file_path)
        return True

    def listify(self, checklst):
        '''
        Some functions in data.py expect list values, so creates them.
        checklst is list-like which contains the selected keys to be returned.
        Returns a new dict with only the request_body keys found in checklst;
        values that are not already lists are wrapped in a single-item list.
        '''
        return {
            key: value if isinstance(value, list) else [value]
            for key, value in self.request_body.items()
            if key in checklst
        }

172 k: [v] if not isinstance(v, list) else v 

173 for k, v in self.request_body.items() 

174 if k in checklst 

175 }