001package com.facebook.swift.service;
002
003import com.facebook.nifty.core.RequestContext;
004import com.facebook.swift.codec.ThriftCodecManager;
005import com.facebook.swift.service.ContextChain;
006import com.facebook.swift.service.ThriftEventHandler;
007import com.facebook.swift.service.ThriftMethodProcessor;
008import com.facebook.swift.service.ThriftServiceProcessor;
009import com.facebook.swift.service.metadata.ThriftMethodMetadata;
010import com.facebook.swift.service.metadata.ThriftServiceMetadata;
011import com.google.common.base.Preconditions;
012import com.google.common.collect.ImmutableList;
013import com.google.common.collect.ImmutableMap;
014import com.google.common.util.concurrent.FutureCallback;
015import com.google.common.util.concurrent.Futures;
016import com.google.common.util.concurrent.ListenableFuture;
017import com.google.common.util.concurrent.SettableFuture;
018import org.apache.thrift.TApplicationException;
019import org.apache.thrift.TException;
020import org.apache.thrift.protocol.TMessage;
021import org.apache.thrift.protocol.TMessageType;
022import org.apache.thrift.protocol.TProtocol;
023import org.apache.thrift.protocol.TProtocolUtil;
024import org.apache.thrift.protocol.TType;
025import org.slf4j.Logger;
026import org.slf4j.LoggerFactory;
027
028import java.util.List;
029import java.util.Map;
030
031import javax.annotation.concurrent.ThreadSafe;
032
033import static com.google.common.collect.Maps.newHashMap;
034import static org.apache.thrift.TApplicationException.INVALID_MESSAGE_TYPE;
035import static org.apache.thrift.TApplicationException.UNKNOWN_METHOD;
036
037/**
038 * {@link ThriftServiceProcessor}子类,代码基本都是从父类复制过来,
039 * 只是在构造方法{@link #ThriftServiceProcessorCustom(ThriftCodecManager, List, List)}中
040 * 用{@link ThriftMethodProcessorCustom}替换了{@link ThriftMethodProcessor}
041 * @author guyadong
042 *
043 */
044@ThreadSafe
045public class ThriftServiceProcessorCustom extends ThriftServiceProcessor
046{
047    private static final Logger LOG = LoggerFactory.getLogger(ThriftServiceProcessorCustom.class);
048
049    private final Map<String, ThriftMethodProcessor> methods;
050    private final List<ThriftEventHandler> eventHandlers;
051
052    /**
053     * @param eventHandlers event handlers to attach to services
054     * @param services the services to expose; services must be thread safe
055     */
056    public ThriftServiceProcessorCustom(ThriftCodecManager codecManager, List<? extends ThriftEventHandler> eventHandlers, Object... services)
057    {
058        this(codecManager, eventHandlers, ImmutableList.copyOf(services));
059    }
060
061    public ThriftServiceProcessorCustom(ThriftCodecManager codecManager, List<? extends ThriftEventHandler> eventHandlers, List<?> services)
062    {
063        super(codecManager, eventHandlers, services);
064        Preconditions.checkArgument(null != codecManager, "codecManager is null");
065        Preconditions.checkArgument(null != services, "services is null");
066        Preconditions.checkArgument(!services.isEmpty(), "services is empty");
067
068        Map<String, ThriftMethodProcessor> processorMap = newHashMap();
069        for (Object service : services) {
070            ThriftServiceMetadata serviceMetadata = new ThriftServiceMetadata(service.getClass(), codecManager.getCatalog());
071            for (ThriftMethodMetadata methodMetadata : serviceMetadata.getMethods().values()) {
072                String methodName = methodMetadata.getName();
073                // 替换ThriftMethodProcessor
074                ThriftMethodProcessor methodProcessor = new ThriftMethodProcessorCustom(service, serviceMetadata.getName(), methodMetadata, codecManager);
075                if (processorMap.containsKey(methodName)) {
076                    throw new IllegalArgumentException("Multiple @ThriftMethod-annotated methods named '" + methodName + "' found in the given services");
077                }
078                processorMap.put(methodName, methodProcessor);
079            }
080        }
081        methods = ImmutableMap.copyOf(processorMap);
082        this.eventHandlers = ImmutableList.copyOf(eventHandlers);
083    }
084    @Override
085    public Map<String, ThriftMethodProcessor> getMethods()
086    {
087        return methods;
088    }
089
090    @Override
091    public ListenableFuture<Boolean> process(final TProtocol in, TProtocol out, RequestContext requestContext)
092            throws TException
093    {
094        try {
095            final SettableFuture<Boolean> resultFuture = SettableFuture.create();
096            TMessage message = in.readMessageBegin();
097            String methodName = message.name;
098            int sequenceId = message.seqid;
099
100            // lookup method
101            ThriftMethodProcessor method = methods.get(methodName);
102            if (method == null) {
103                TProtocolUtil.skip(in, TType.STRUCT);
104                writeApplicationException(out, methodName, sequenceId, UNKNOWN_METHOD, "Invalid method name: '" + methodName + "'", null);
105                return Futures.immediateFuture(true);
106            }
107
108            switch (message.type) {
109                case TMessageType.CALL:
110                case TMessageType.ONEWAY:
111                    // Ideally we'd check the message type here to make the presence/absence of
112                    // the "oneway" keyword annotating the method matches the message type.
113                    // Unfortunately most clients send both one-way and two-way messages as CALL
114                    // message type instead of using ONEWAY message type, and servers ignore the
115                    // difference.
116                    break;
117
118                default:
119                    TProtocolUtil.skip(in, TType.STRUCT);
120                    writeApplicationException(out, methodName, sequenceId, INVALID_MESSAGE_TYPE, "Received invalid message type " + message.type + " from client", null);
121                    return Futures.immediateFuture(true);
122            }
123
124            // invoke method
125            final ContextChain context = new ContextChain(eventHandlers, method.getQualifiedName(), requestContext);
126            ListenableFuture<Boolean> processResult = method.process(in, out, sequenceId, context);
127
128            Futures.addCallback(
129                    processResult,
130                    new FutureCallback<Boolean>()
131                    {
132                        @Override
133                        public void onSuccess(Boolean result)
134                        {
135                            context.done();
136                            @SuppressWarnings("unused")
137                                                        boolean b = resultFuture.set(result);
138                        }
139
140                        @Override
141                        public void onFailure(Throwable t)
142                        {
143                            context.done();
144                            @SuppressWarnings("unused")
145                                                        boolean b = resultFuture.setException(t);
146                        }
147                    });
148
149            return resultFuture;
150        }
151        catch (Exception e) {
152            return Futures.immediateFailedFuture(e);
153        }
154    }
155
156    public static TApplicationException writeApplicationException(
157            TProtocol outputProtocol,
158            String methodName,
159            int sequenceId,
160            int errorCode,
161            String errorMessage,
162            Throwable cause)
163            throws TException
164    {
165        // unexpected exception
166        TApplicationException applicationException = new TApplicationException(errorCode, errorMessage);
167        if (cause != null) {
168            applicationException.initCause(cause);
169        }
170
171        LOG.error(errorMessage, applicationException);
172
173        // Application exceptions are sent to client, and the connection can be reused
174        outputProtocol.writeMessageBegin(new TMessage(methodName, TMessageType.EXCEPTION, sequenceId));
175        applicationException.write(outputProtocol);
176        outputProtocol.writeMessageEnd();
177        outputProtocol.getTransport().flush();
178
179        return applicationException;
180    }
181}