001package com.facebook.swift.service;
002
003import com.facebook.nifty.core.RequestContext;
004import com.facebook.nifty.core.RequestContexts;
005import com.facebook.swift.codec.ThriftCodec;
006import com.facebook.swift.codec.ThriftCodecManager;
007import com.facebook.swift.codec.internal.TProtocolReader;
008import com.facebook.swift.codec.internal.TProtocolWriter;
009import com.facebook.swift.codec.metadata.ThriftFieldMetadata;
010import com.facebook.swift.codec.metadata.ThriftType;
011import com.facebook.swift.service.ContextChain;
012import com.facebook.swift.service.ThriftServiceProcessor;
013import com.facebook.swift.service.metadata.ThriftMethodMetadata;
014import com.google.common.base.Defaults;
015import com.google.common.collect.ImmutableList;
016import com.google.common.collect.ImmutableMap;
017import com.google.common.collect.ImmutableMap.Builder;
018import com.google.common.primitives.Primitives;
019import com.google.common.reflect.TypeToken;
020import com.google.common.util.concurrent.FutureCallback;
021import com.google.common.util.concurrent.Futures;
022import com.google.common.util.concurrent.ListenableFuture;
023import com.google.common.util.concurrent.SettableFuture;
024import org.apache.thrift.TApplicationException;
025import org.apache.thrift.protocol.TMessage;
026import org.apache.thrift.protocol.TMessageType;
027import org.apache.thrift.protocol.TProtocol;
028import org.apache.thrift.protocol.TProtocolException;
029import org.weakref.jmx.Managed;
030import com.google.common.base.Preconditions;
031
032import javax.annotation.concurrent.ThreadSafe;
033import java.lang.reflect.InvocationTargetException;
034import java.lang.reflect.Method;
035import java.lang.reflect.Type;
036import java.util.Map;
037
038import static org.apache.thrift.TApplicationException.INTERNAL_ERROR;
039
040/**
041 * 定制{@link ThriftMethodProcessor},收到的primitive封装类型(Boolean,Integer...)如果是{@code null}不会被转成default 0.
042 * 通过调用 {@link #setUseDefaultValueIfPrimitiveWrap(boolean)} 方法设置为{@code true}可以禁用该特性
043 * @author guyadong
044 *
045 */
046@ThreadSafe
047public class ThriftMethodProcessorCustom extends ThriftMethodProcessor
048{
049    private final String name;
050    private final String serviceName;
051    private final String qualifiedName;
052    private final Object service;
053    private final Method method;
054    private final String resultStructName;
055    private final boolean oneway;
056    private final ImmutableList<ThriftFieldMetadata> parameters;
057    private final Map<Short, ThriftCodec<?>> parameterCodecs;
058    private final Map<Short, Short> thriftParameterIdToJavaArgumentListPositionMap;
059    private final ThriftCodec<Object> successCodec;
060    private final Map<Class<?>, ExceptionProcessor> exceptionCodecs;
061    /** 
062     * 全局开关<br>
063     * 为{@code true}时对于primitive wrap类型参数,将{@code null}转为default value 
064     */
065    private static Boolean useDefaultValueIfPrimitiveWrap= null;
066    @SuppressWarnings("unchecked")
067        public ThriftMethodProcessorCustom(
068            Object service,
069            String serviceName,
070            ThriftMethodMetadata methodMetadata,
071            ThriftCodecManager codecManager
072    )
073    {
074        super(service, serviceName, methodMetadata, codecManager);
075        this.service = service;
076        this.serviceName = serviceName;
077
078        name = methodMetadata.getName();
079        qualifiedName = serviceName + "." + name;
080        resultStructName = name + "_result";
081
082        method = methodMetadata.getMethod();
083        oneway = methodMetadata.getOneway();
084
085        parameters = ImmutableList.copyOf(methodMetadata.getParameters());
086
087        ImmutableMap.Builder<Short, ThriftCodec<?>> builder = ImmutableMap.builder();
088        for (ThriftFieldMetadata fieldMetadata : methodMetadata.getParameters()) {
089            @SuppressWarnings("unused")
090                        Builder<Short, ThriftCodec<?>> p = builder.put(fieldMetadata.getId(), codecManager.getCodec(fieldMetadata.getThriftType()));
091        }
092        parameterCodecs = builder.build();
093
094        // Build a mapping from thrift parameter ID to a position in the formal argument list
095        ImmutableMap.Builder<Short, Short> parameterOrderingBuilder = ImmutableMap.builder();
096        short javaArgumentPosition = 0;
097        for (ThriftFieldMetadata fieldMetadata : methodMetadata.getParameters()) {
098            @SuppressWarnings("unused")
099                        Builder<Short, Short> p = parameterOrderingBuilder.put(fieldMetadata.getId(), javaArgumentPosition++);
100        }
101        thriftParameterIdToJavaArgumentListPositionMap = parameterOrderingBuilder.build();
102
103        ImmutableMap.Builder<Class<?>, ExceptionProcessor> exceptions = ImmutableMap.builder();
104        for (Map.Entry<Short, ThriftType> entry : methodMetadata.getExceptions().entrySet()) {
105            Class<?> type = TypeToken.of(entry.getValue().getJavaType()).getRawType();
106            ExceptionProcessor processor = new ExceptionProcessor(entry.getKey(), codecManager.getCodec(entry.getValue()));
107            @SuppressWarnings("unused")
108                        Builder<Class<?>, ExceptionProcessor> p = exceptions.put(type, processor);
109        }
110        exceptionCodecs = exceptions.build();
111
112        successCodec = (ThriftCodec<Object>) codecManager.getCodec(methodMetadata.getReturnType());
113    }
114
115    @Managed
116    @Override
117    public String getName()
118    {
119        return name;
120    }
121    @Override
122    public Class<?> getServiceClass() {
123        return service.getClass();
124    }
125    @Override
126    public String getServiceName()
127    {
128        return serviceName;
129    }
130    @Override
131    public String getQualifiedName()
132    {
133        return qualifiedName;
134    }
135    @Override
136    public ListenableFuture<Boolean> process(TProtocol in, final TProtocol out, final int sequenceId, final ContextChain contextChain)
137            throws Exception
138    {
139        // read args
140        contextChain.preRead();
141        Object[] args = readArguments(in);
142        contextChain.postRead(args);
143        final RequestContext requestContext = RequestContexts.getCurrentContext();
144
145        in.readMessageEnd();
146
147        // invoke method
148        final ListenableFuture<?> invokeFuture = invokeMethod(args);
149        final SettableFuture<Boolean> resultFuture = SettableFuture.create();
150
151        Futures.addCallback(invokeFuture, new FutureCallback<Object>()
152        {
153            @Override
154            public void onSuccess(Object result)
155            {
156                if (oneway) {
157                    @SuppressWarnings("unused")
158                                        boolean s = resultFuture.set(true);
159                }
160                else {
161                    RequestContext oldRequestContext = RequestContexts.getCurrentContext();
162                    RequestContexts.setCurrentContext(requestContext);
163
164                    // write success reply
165                    try {
166                        contextChain.preWrite(result);
167
168                        writeResponse(out,
169                                      sequenceId,
170                                      TMessageType.REPLY,
171                                      "success",
172                                      (short) 0,
173                                      successCodec,
174                                      result);
175
176                        contextChain.postWrite(result);
177
178                        @SuppressWarnings("unused")
179                                                boolean s = resultFuture.set(true);
180                    }
181                    catch (Exception e) {
182                        // An exception occurred trying to serialize a return value onto the output protocol
183                        @SuppressWarnings("unused")
184                                                boolean b = resultFuture.setException(e);
185                    }
186                    finally {
187                        RequestContexts.setCurrentContext(oldRequestContext);
188                    }
189                }
190            }
191
192            @Override
193            public void onFailure(Throwable t)
194            {
195                RequestContext oldRequestContext = RequestContexts.getCurrentContext();
196                RequestContexts.setCurrentContext(requestContext);
197
198                try {
199                    contextChain.preWriteException(t);
200                    if (!oneway) {
201                        ExceptionProcessor exceptionCodec = exceptionCodecs.get(t.getClass());
202                        if (exceptionCodec != null) {
203                            // write expected exception response
204                            writeResponse(
205                                    out,
206                                    sequenceId,
207                                    TMessageType.REPLY,
208                                    "exception",
209                                    exceptionCodec.getId(),
210                                    exceptionCodec.getCodec(),
211                                    t);
212                            contextChain.postWriteException(t);
213                        } else {
214                            // unexpected exception
215                            TApplicationException applicationException =
216                                    ThriftServiceProcessor.writeApplicationException(
217                                            out,
218                                            method.getName(),
219                                            sequenceId,
220                                            INTERNAL_ERROR,
221                                            "Internal error processing " + method.getName(),
222                                            t);
223
224                            contextChain.postWriteException(applicationException);
225                        }
226                    }
227
228                    @SuppressWarnings("unused")
229                                        boolean b = resultFuture.set(true);
230                }
231                catch (Exception e) {
232                    // An exception occurred trying to serialize an exception onto the output protocol
233                    @SuppressWarnings("unused")
234                                        boolean b = resultFuture.setException(e);
235                }
236                finally {
237                    RequestContexts.setCurrentContext(oldRequestContext);
238                }
239            }
240        });
241
242        return resultFuture;
243    }
244
245    private ListenableFuture<?> invokeMethod(Object[] args)
246    {
247        try {
248            Object response = method.invoke(service, args);
249            if (response instanceof ListenableFuture) {
250                return (ListenableFuture<?>) response;
251            }
252            return Futures.immediateFuture(response);
253        }
254        catch (IllegalAccessException | IllegalArgumentException e) {
255            // These really should never happen, since the method metadata should have prevented it
256            return Futures.immediateFailedFuture(e);
257        }
258        catch (InvocationTargetException e) {
259            Throwable cause = e.getCause();
260            if (cause != null) {
261                return Futures.immediateFailedFuture(cause);
262            }
263
264            return Futures.immediateFailedFuture(e);
265        }
266    }
267
268    private Object[] readArguments(TProtocol in)
269            throws Exception
270    {
271        try {
272            int numArgs = method.getParameterTypes().length;
273            Object[] args = new Object[numArgs];
274            TProtocolReader reader = new TProtocolReader(in);
275
276            // Map incoming arguments from the ID passed in on the wire to the position in the
277            // java argument list we expect to see a parameter with that ID.
278            reader.readStructBegin();
279            while (reader.nextField()) {
280                short fieldId = reader.getFieldId();
281
282                ThriftCodec<?> codec = parameterCodecs.get(fieldId);
283                if (codec == null) {
284                    // unknown field
285                    reader.skipFieldData();
286                }
287                else {
288                    // Map the incoming arguments to an array of arguments ordered as the java
289                    // code for the handler method expects to see them
290                    args[thriftParameterIdToJavaArgumentListPositionMap.get(fieldId)] = reader.readField(codec);
291                }
292            }
293            reader.readStructEnd();
294
295            // Walk through our list of expected parameters and if no incoming parameters were
296            // mapped to a particular expected parameter, fill the expected parameter slow with
297            // the default for the parameter type.
298            // 根据全局开关判断是否对null替换为default value
299            if(Boolean.TRUE.equals(useDefaultValueIfPrimitiveWrap)){
300                int argumentPosition = 0;
301                for (ThriftFieldMetadata argument : parameters) {
302                        if (args[argumentPosition] == null) {
303                                Type argumentType = argument.getThriftType().getJavaType();
304
305                                if (argumentType instanceof Class) {
306                                        Class<?> argumentClass = (Class<?>) argumentType;
307                                        argumentClass = Primitives.unwrap(argumentClass);
308                                        args[argumentPosition] = Defaults.defaultValue(argumentClass);
309                                }
310                        }
311                        argumentPosition++;
312                }
313            }
314            return args;
315        }
316        catch (TProtocolException e) {
317            // TProtocolException is the only recoverable exception
318            // Other exceptions may have left the input stream in corrupted state so we must
319            // tear down the socket.
320            throw new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
321        }
322    }
323
324    private <T> void writeResponse(TProtocol out,
325                                   int sequenceId,
326                                   byte responseType,
327                                   String responseFieldName,
328                                   short responseFieldId,
329                                   ThriftCodec<T> responseCodec,
330                                   T result) throws Exception
331    {
332        out.writeMessageBegin(new TMessage(name, responseType, sequenceId));
333
334        TProtocolWriter writer = new TProtocolWriter(out);
335        writer.writeStructBegin(resultStructName);
336        writer.writeField(responseFieldName, (short) responseFieldId, responseCodec, result);
337        writer.writeStructEnd();
338
339        out.writeMessageEnd();
340        out.getTransport().flush();
341    }
342
343    private static final class ExceptionProcessor
344    {
345        private final short id;
346        private final ThriftCodec<Object> codec;
347
348        @SuppressWarnings("unchecked")
349                private ExceptionProcessor(short id, ThriftCodec<?> coded)
350        {
351            this.id = id;
352            this.codec = (ThriftCodec<Object>) coded;
353        }
354
355        public short getId()
356        {
357            return id;
358        }
359
360        public ThriftCodec<Object> getCodec()
361        {
362            return codec;
363        }
364    }
365
366        /**
367         * 设置对于primitive wrap类型是否将{@code null}转为default value,默认为{@code false} <br>
368         * 该方法只能被调用一次
369         * @param useDefaultValueIfPrimitiveWrap
370         * @throws IllegalStateException 该方法已经被调用过
371         */
372        public static synchronized void setUseDefaultValueIfPrimitiveWrap(boolean useDefaultValueIfPrimitiveWrap) {
373                Preconditions.checkState(null == ThriftMethodProcessorCustom.useDefaultValueIfPrimitiveWrap,
374                                "useDefaultValueIfPrimitiveWrap can be initialized only once");
375                ThriftMethodProcessorCustom.useDefaultValueIfPrimitiveWrap = useDefaultValueIfPrimitiveWrap;
376        }
377}