001package com.facebook.swift.service; 002 003import com.facebook.nifty.core.RequestContext; 004import com.facebook.nifty.core.RequestContexts; 005import com.facebook.swift.codec.ThriftCodec; 006import com.facebook.swift.codec.ThriftCodecManager; 007import com.facebook.swift.codec.internal.TProtocolReader; 008import com.facebook.swift.codec.internal.TProtocolWriter; 009import com.facebook.swift.codec.metadata.ThriftFieldMetadata; 010import com.facebook.swift.codec.metadata.ThriftType; 011import com.facebook.swift.service.ContextChain; 012import com.facebook.swift.service.ThriftServiceProcessor; 013import com.facebook.swift.service.metadata.ThriftMethodMetadata; 014import com.google.common.base.Defaults; 015import com.google.common.collect.ImmutableList; 016import com.google.common.collect.ImmutableMap; 017import com.google.common.collect.ImmutableMap.Builder; 018import com.google.common.primitives.Primitives; 019import com.google.common.reflect.TypeToken; 020import com.google.common.util.concurrent.FutureCallback; 021import com.google.common.util.concurrent.Futures; 022import com.google.common.util.concurrent.ListenableFuture; 023import com.google.common.util.concurrent.SettableFuture; 024import org.apache.thrift.TApplicationException; 025import org.apache.thrift.protocol.TMessage; 026import org.apache.thrift.protocol.TMessageType; 027import org.apache.thrift.protocol.TProtocol; 028import org.apache.thrift.protocol.TProtocolException; 029import org.weakref.jmx.Managed; 030import com.google.common.base.Preconditions; 031 032import javax.annotation.concurrent.ThreadSafe; 033import java.lang.reflect.InvocationTargetException; 034import java.lang.reflect.Method; 035import java.lang.reflect.Type; 036import java.util.Map; 037 038import static org.apache.thrift.TApplicationException.INTERNAL_ERROR; 039 040/** 041 * 定制{@link ThriftMethodProcessor},收到的primitiveå°è£…类型(Boolean,Integer...)如果是{@code null}ä¸ä¼šè¢«è½¬æˆdefault 0. 042 * 通过调用 {@link #setUseDefaultValueIfPrimitiveWrap(boolean)} 方法设置为{@code true}å¯ä»¥ç¦ç”¨è¯¥ç‰¹æ€§ 043 * @author guyadong 044 * 045 */ 046@ThreadSafe 047public class ThriftMethodProcessorCustom extends ThriftMethodProcessor 048{ 049 private final String name; 050 private final String serviceName; 051 private final String qualifiedName; 052 private final Object service; 053 private final Method method; 054 private final String resultStructName; 055 private final boolean oneway; 056 private final ImmutableList<ThriftFieldMetadata> parameters; 057 private final Map<Short, ThriftCodec<?>> parameterCodecs; 058 private final Map<Short, Short> thriftParameterIdToJavaArgumentListPositionMap; 059 private final ThriftCodec<Object> successCodec; 060 private final Map<Class<?>, ExceptionProcessor> exceptionCodecs; 061 /** 062 * 全局开关<br> 063 * 为{@code true}时对于primitive wrapç±»åž‹å‚æ•°,å°†{@code null}转为default value 064 */ 065 private static Boolean useDefaultValueIfPrimitiveWrap= null; 066 @SuppressWarnings("unchecked") 067 public ThriftMethodProcessorCustom( 068 Object service, 069 String serviceName, 070 ThriftMethodMetadata methodMetadata, 071 ThriftCodecManager codecManager 072 ) 073 { 074 super(service, serviceName, methodMetadata, codecManager); 075 this.service = service; 076 this.serviceName = serviceName; 077 078 name = methodMetadata.getName(); 079 qualifiedName = serviceName + "." + name; 080 resultStructName = name + "_result"; 081 082 method = methodMetadata.getMethod(); 083 oneway = methodMetadata.getOneway(); 084 085 parameters = ImmutableList.copyOf(methodMetadata.getParameters()); 086 087 ImmutableMap.Builder<Short, ThriftCodec<?>> builder = ImmutableMap.builder(); 088 for (ThriftFieldMetadata fieldMetadata : methodMetadata.getParameters()) { 089 @SuppressWarnings("unused") 090 Builder<Short, ThriftCodec<?>> p = builder.put(fieldMetadata.getId(), codecManager.getCodec(fieldMetadata.getThriftType())); 091 } 092 parameterCodecs = builder.build(); 093 094 // Build a mapping from thrift parameter ID to a position in the formal argument list 095 ImmutableMap.Builder<Short, Short> parameterOrderingBuilder = ImmutableMap.builder(); 096 short javaArgumentPosition = 0; 097 for (ThriftFieldMetadata fieldMetadata : methodMetadata.getParameters()) { 098 @SuppressWarnings("unused") 099 Builder<Short, Short> p = parameterOrderingBuilder.put(fieldMetadata.getId(), javaArgumentPosition++); 100 } 101 thriftParameterIdToJavaArgumentListPositionMap = parameterOrderingBuilder.build(); 102 103 ImmutableMap.Builder<Class<?>, ExceptionProcessor> exceptions = ImmutableMap.builder(); 104 for (Map.Entry<Short, ThriftType> entry : methodMetadata.getExceptions().entrySet()) { 105 Class<?> type = TypeToken.of(entry.getValue().getJavaType()).getRawType(); 106 ExceptionProcessor processor = new ExceptionProcessor(entry.getKey(), codecManager.getCodec(entry.getValue())); 107 @SuppressWarnings("unused") 108 Builder<Class<?>, ExceptionProcessor> p = exceptions.put(type, processor); 109 } 110 exceptionCodecs = exceptions.build(); 111 112 successCodec = (ThriftCodec<Object>) codecManager.getCodec(methodMetadata.getReturnType()); 113 } 114 115 @Managed 116 @Override 117 public String getName() 118 { 119 return name; 120 } 121 @Override 122 public Class<?> getServiceClass() { 123 return service.getClass(); 124 } 125 @Override 126 public String getServiceName() 127 { 128 return serviceName; 129 } 130 @Override 131 public String getQualifiedName() 132 { 133 return qualifiedName; 134 } 135 @Override 136 public ListenableFuture<Boolean> process(TProtocol in, final TProtocol out, final int sequenceId, final ContextChain contextChain) 137 throws Exception 138 { 139 // read args 140 contextChain.preRead(); 141 Object[] args = readArguments(in); 142 contextChain.postRead(args); 143 final RequestContext requestContext = RequestContexts.getCurrentContext(); 144 145 in.readMessageEnd(); 146 147 // invoke method 148 final ListenableFuture<?> invokeFuture = invokeMethod(args); 149 final SettableFuture<Boolean> resultFuture = SettableFuture.create(); 150 151 Futures.addCallback(invokeFuture, new FutureCallback<Object>() 152 { 153 @Override 154 public void onSuccess(Object result) 155 { 156 if (oneway) { 157 @SuppressWarnings("unused") 158 boolean s = resultFuture.set(true); 159 } 160 else { 161 RequestContext oldRequestContext = RequestContexts.getCurrentContext(); 162 RequestContexts.setCurrentContext(requestContext); 163 164 // write success reply 165 try { 166 contextChain.preWrite(result); 167 168 writeResponse(out, 169 sequenceId, 170 TMessageType.REPLY, 171 "success", 172 (short) 0, 173 successCodec, 174 result); 175 176 contextChain.postWrite(result); 177 178 @SuppressWarnings("unused") 179 boolean s = resultFuture.set(true); 180 } 181 catch (Exception e) { 182 // An exception occurred trying to serialize a return value onto the output protocol 183 @SuppressWarnings("unused") 184 boolean b = resultFuture.setException(e); 185 } 186 finally { 187 RequestContexts.setCurrentContext(oldRequestContext); 188 } 189 } 190 } 191 192 @Override 193 public void onFailure(Throwable t) 194 { 195 RequestContext oldRequestContext = RequestContexts.getCurrentContext(); 196 RequestContexts.setCurrentContext(requestContext); 197 198 try { 199 contextChain.preWriteException(t); 200 if (!oneway) { 201 ExceptionProcessor exceptionCodec = exceptionCodecs.get(t.getClass()); 202 if (exceptionCodec != null) { 203 // write expected exception response 204 writeResponse( 205 out, 206 sequenceId, 207 TMessageType.REPLY, 208 "exception", 209 exceptionCodec.getId(), 210 exceptionCodec.getCodec(), 211 t); 212 contextChain.postWriteException(t); 213 } else { 214 // unexpected exception 215 TApplicationException applicationException = 216 ThriftServiceProcessor.writeApplicationException( 217 out, 218 method.getName(), 219 sequenceId, 220 INTERNAL_ERROR, 221 "Internal error processing " + method.getName(), 222 t); 223 224 contextChain.postWriteException(applicationException); 225 } 226 } 227 228 @SuppressWarnings("unused") 229 boolean b = resultFuture.set(true); 230 } 231 catch (Exception e) { 232 // An exception occurred trying to serialize an exception onto the output protocol 233 @SuppressWarnings("unused") 234 boolean b = resultFuture.setException(e); 235 } 236 finally { 237 RequestContexts.setCurrentContext(oldRequestContext); 238 } 239 } 240 }); 241 242 return resultFuture; 243 } 244 245 private ListenableFuture<?> invokeMethod(Object[] args) 246 { 247 try { 248 Object response = method.invoke(service, args); 249 if (response instanceof ListenableFuture) { 250 return (ListenableFuture<?>) response; 251 } 252 return Futures.immediateFuture(response); 253 } 254 catch (IllegalAccessException | IllegalArgumentException e) { 255 // These really should never happen, since the method metadata should have prevented it 256 return Futures.immediateFailedFuture(e); 257 } 258 catch (InvocationTargetException e) { 259 Throwable cause = e.getCause(); 260 if (cause != null) { 261 return Futures.immediateFailedFuture(cause); 262 } 263 264 return Futures.immediateFailedFuture(e); 265 } 266 } 267 268 private Object[] readArguments(TProtocol in) 269 throws Exception 270 { 271 try { 272 int numArgs = method.getParameterTypes().length; 273 Object[] args = new Object[numArgs]; 274 TProtocolReader reader = new TProtocolReader(in); 275 276 // Map incoming arguments from the ID passed in on the wire to the position in the 277 // java argument list we expect to see a parameter with that ID. 278 reader.readStructBegin(); 279 while (reader.nextField()) { 280 short fieldId = reader.getFieldId(); 281 282 ThriftCodec<?> codec = parameterCodecs.get(fieldId); 283 if (codec == null) { 284 // unknown field 285 reader.skipFieldData(); 286 } 287 else { 288 // Map the incoming arguments to an array of arguments ordered as the java 289 // code for the handler method expects to see them 290 args[thriftParameterIdToJavaArgumentListPositionMap.get(fieldId)] = reader.readField(codec); 291 } 292 } 293 reader.readStructEnd(); 294 295 // Walk through our list of expected parameters and if no incoming parameters were 296 // mapped to a particular expected parameter, fill the expected parameter slow with 297 // the default for the parameter type. 298 // æ ¹æ®å…¨å±€å¼€å…³åˆ¤æ–是å¦å¯¹null替æ¢ä¸ºdefault value 299 if(Boolean.TRUE.equals(useDefaultValueIfPrimitiveWrap)){ 300 int argumentPosition = 0; 301 for (ThriftFieldMetadata argument : parameters) { 302 if (args[argumentPosition] == null) { 303 Type argumentType = argument.getThriftType().getJavaType(); 304 305 if (argumentType instanceof Class) { 306 Class<?> argumentClass = (Class<?>) argumentType; 307 argumentClass = Primitives.unwrap(argumentClass); 308 args[argumentPosition] = Defaults.defaultValue(argumentClass); 309 } 310 } 311 argumentPosition++; 312 } 313 } 314 return args; 315 } 316 catch (TProtocolException e) { 317 // TProtocolException is the only recoverable exception 318 // Other exceptions may have left the input stream in corrupted state so we must 319 // tear down the socket. 320 throw new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage()); 321 } 322 } 323 324 private <T> void writeResponse(TProtocol out, 325 int sequenceId, 326 byte responseType, 327 String responseFieldName, 328 short responseFieldId, 329 ThriftCodec<T> responseCodec, 330 T result) throws Exception 331 { 332 out.writeMessageBegin(new TMessage(name, responseType, sequenceId)); 333 334 TProtocolWriter writer = new TProtocolWriter(out); 335 writer.writeStructBegin(resultStructName); 336 writer.writeField(responseFieldName, (short) responseFieldId, responseCodec, result); 337 writer.writeStructEnd(); 338 339 out.writeMessageEnd(); 340 out.getTransport().flush(); 341 } 342 343 private static final class ExceptionProcessor 344 { 345 private final short id; 346 private final ThriftCodec<Object> codec; 347 348 @SuppressWarnings("unchecked") 349 private ExceptionProcessor(short id, ThriftCodec<?> coded) 350 { 351 this.id = id; 352 this.codec = (ThriftCodec<Object>) coded; 353 } 354 355 public short getId() 356 { 357 return id; 358 } 359 360 public ThriftCodec<Object> getCodec() 361 { 362 return codec; 363 } 364 } 365 366 /** 367 * 设置对于primitive wrap类型是å¦å°†{@code null}转为default value,默认为{@code false} <br> 368 * 该方法åªèƒ½è¢«è°ƒç”¨ä¸€æ¬¡ 369 * @param useDefaultValueIfPrimitiveWrap 370 * @throws IllegalStateException 该方法已ç»è¢«è°ƒç”¨è¿‡ 371 */ 372 public static synchronized void setUseDefaultValueIfPrimitiveWrap(boolean useDefaultValueIfPrimitiveWrap) { 373 Preconditions.checkState(null == ThriftMethodProcessorCustom.useDefaultValueIfPrimitiveWrap, 374 "useDefaultValueIfPrimitiveWrap can be initialized only once"); 375 ThriftMethodProcessorCustom.useDefaultValueIfPrimitiveWrap = useDefaultValueIfPrimitiveWrap; 376 } 377}