001package com.facebook.swift.service;
002
003import java.lang.reflect.Constructor;
004import java.lang.reflect.Field;
005import java.lang.reflect.Modifier;
006import java.util.List;
007import org.apache.thrift.protocol.TJSONProtocol;
008import org.jboss.netty.channel.ChannelPipeline;
009import org.jboss.netty.channel.ChannelPipelineFactory;
010import org.jboss.netty3.handler.codec.http.cors.Netty3CorsConfig;
011import org.jboss.netty3.handler.codec.http.cors.Netty3CorsConfigBuilder;
012import org.jboss.netty3.handler.codec.http.cors.Netty3CorsHandler;
013import org.slf4j.Logger;
014import org.slf4j.LoggerFactory;
015import com.google.common.base.Throwables;
016
017import com.google.common.base.Joiner;
018import com.google.common.base.Strings;
019
020import static com.google.common.base.Preconditions.*;
021import static org.jboss.netty.handler.codec.http.HttpMethod.*;
022
023import com.facebook.nifty.codec.ThriftFrameCodecFactory;
024import com.facebook.nifty.core.NettyServerTransport;
025import com.facebook.nifty.core.NiftyTimer;
026import com.facebook.nifty.duplex.TDuplexProtocolFactory;
027import com.facebook.swift.codec.ThriftCodecManager;
028import com.facebook.swift.service.ThriftEventHandler;
029import com.facebook.swift.service.ThriftServer;
030import com.facebook.swift.service.ThriftServerConfig;
031import com.facebook.swift.service.ThriftService;
032import com.facebook.swift.service.ThriftServiceProcessor;
033import com.facebook.swift.service.metadata.ThriftServiceMetadata;
034import com.google.common.base.Function;
035import com.google.common.collect.ImmutableList;
036import com.google.common.collect.ImmutableMap;
037import com.google.common.collect.Lists;
038import com.google.common.util.concurrent.AbstractIdleService;
039import com.google.common.util.concurrent.MoreExecutors;
040
041import net.gdface.utils.ReflectionUtils;
042
043/**
044 * 创建thrift服务实例{@link ThriftServer},封装为{@link com.google.common.util.concurrent.Service}
045 * @author guyadong
046 *
047 */
048public class ThriftServerService extends AbstractIdleService{
049    private static final Logger logger = LoggerFactory.getLogger(ThriftServerService.class);
050    public static final String HTTP_TRANSPORT = "http";
051    public static final String JSON_PROTOCOL = "json";
052    /**
053     * 在{@link ThriftServer#DEFAULT_PROTOCOL_FACTORIES}基础上增加'json'支持
054     */
055    public static final ImmutableMap<String,TDuplexProtocolFactory> DEFAULT_PROTOCOL_FACTORIES = 
056                ImmutableMap.<String, TDuplexProtocolFactory>builder()
057                        .putAll(ThriftServer.DEFAULT_PROTOCOL_FACTORIES)
058                        .put(JSON_PROTOCOL, TDuplexProtocolFactory.fromSingleFactory(new TJSONProtocol.Factory()))
059                        .build();
060    /**
061     * 在{@link ThriftServer#DEFAULT_FRAME_CODEC_FACTORIES}基础上增加'http'支持
062     */
063    public static final ImmutableMap<String,ThriftFrameCodecFactory> DEFAULT_FRAME_CODEC_FACTORIES = 
064                ImmutableMap.<String, ThriftFrameCodecFactory>builder()
065                        .putAll(ThriftServer.DEFAULT_FRAME_CODEC_FACTORIES)
066                        .put(HTTP_TRANSPORT, (ThriftFrameCodecFactory) new ThriftHttpCodecFactory())
067                        .build();
068
069        public static class Builder {
070                private List<?> services = ImmutableList.of();
071                private ThriftServerConfig thriftServerConfig= new ThriftServerConfig();
072                private List<ThriftEventHandler> eventHandlers = ImmutableList.of();
073                private Builder() {
074                }
075
076                public Builder withServices(Object... services) {
077                        return withServices(ImmutableList.copyOf(services));
078                }
079
080                public Builder withServices(List<?> services) {
081                        this.services = checkNotNull(services);
082                        return this;
083                }
084                public Builder setEventHandlers(List<ThriftEventHandler> eventHandlers){
085                        this.eventHandlers = checkNotNull(eventHandlers);
086                        return this;
087                }
088                public Builder setEventHandlers(ThriftEventHandler...eventHandlers){
089                        return setEventHandlers(ImmutableList.copyOf(eventHandlers));
090                }
091                /**
092                 * 设置服务端口
093                 * @param servicePort
094                 * @return
095                 * @see ThriftServerConfig#setPort(int)
096                 */
097                public Builder setServerPort(int servicePort) {
098                        this.thriftServerConfig.setPort(servicePort);
099                        return this;
100                }
101
102                /**
103                 * 设置服务器配置参数对象
104                 * @param thriftServerConfig
105                 * @return
106                 */
107                public Builder setThriftServerConfig(ThriftServerConfig thriftServerConfig) {
108                        this.thriftServerConfig = checkNotNull(thriftServerConfig,"thriftServerConfig is null");
109                        return this;
110                }
111
112                /**
113                 * 根据参数构造 {@link ThriftServerService}实例
114                 * @return
115                 */
116                public ThriftServerService build() {
117                        return new ThriftServerService(services, eventHandlers, thriftServerConfig);
118                }
119                /**
120                 * 根据参数构造 {@link ThriftServerService}子类实例
121                 * @param subServiceClass
122                 * @return
123                 */
124                public <T extends ThriftServerService> T build(Class<T> subServiceClass) {
125                        try {
126                                Constructor<T> constructor= checkNotNull(subServiceClass,"subServiceClass is null")
127                                                .getDeclaredConstructor(List.class,List.class,ThriftServerConfig.class);
128                                return constructor.newInstance(services,eventHandlers,thriftServerConfig);
129                        } catch (Exception e) {
130                                Throwables.throwIfUnchecked(e);
131                                throw new RuntimeException(e);
132                        }
133                }
134        }
135
136        public static final Builder bulider() {
137                return new Builder();
138        }
139
140        protected final ThriftServer thriftServer;
141        protected final ThriftServiceProcessor processor;
142        protected final ThriftServerConfig thriftServerConfig;
143        protected final String serviceName;
144        /**
145         * 构造函数<br>
146         * @param services 服务对象列表
147         * @param eventHandlers 事件侦听器列表
148         * @param thriftServerConfig 服务配置对象
149         * @see ThriftServiceProcessor#ThriftServiceProcessor(ThriftCodecManager, List, List)
150         * @see ThriftServer#ThriftServer(com.facebook.nifty.processor.NiftyProcessor, ThriftServerConfig)
151         */
152        public ThriftServerService(final List<?> services, 
153                        List<ThriftEventHandler> eventHandlers, 
154                        ThriftServerConfig thriftServerConfig) {
155                checkArgument(null != services && !services.isEmpty());
156                this.thriftServerConfig = checkNotNull(thriftServerConfig,"thriftServerConfig is null");
157                int port = this.thriftServerConfig.getPort();
158                checkArgument(port > 0 && port < 65535,  "INVALID service port %d", port);
159
160                this.processor = new ThriftServiceProcessorCustom(
161                                new ThriftCodecManager(), 
162                                checkNotNull(eventHandlers,"eventHandlers is null"),
163                                services);
164                this.thriftServer =  new ThriftServer(processor,
165                                thriftServerConfig,
166                                new NiftyTimer("thrift"),
167                                DEFAULT_FRAME_CODEC_FACTORIES, DEFAULT_PROTOCOL_FACTORIES, 
168                                ThriftServer.DEFAULT_WORKER_EXECUTORS, 
169                                ThriftServer.DEFAULT_SECURITY_FACTORY);
170                addCorsHandlerIfHttp();
171
172                String serviceList = Joiner.on(",").join(Lists.transform(services, new Function<Object,String>(){
173                        @Override
174                        public String apply(Object input) {
175                                return getServiceName(input);
176                        }}));
177                this.serviceName = String.format("%s(T:%s,P:%s)", 
178                                serviceList,thriftServerConfig.getTransportName(),
179                                thriftServerConfig.getProtocolName());
180                // Arrange to stop the server at shutdown
181                Runtime.getRuntime().addShutdownHook(new Thread() {
182                        @Override
183                        public void run() {
184                                shutDown();
185                        }
186                });
187                addListener(new Listener(){
188                        @Override
189                        public void starting() {
190                                logThriftServerConfig(ThriftServerService.this.thriftServerConfig);
191                        }                       
192                }, MoreExecutors.directExecutor());
193        }
194
195
196        /**
197         * 添加CORS Handler和XHR编解码器
198         */
199        protected void addCorsHandlerIfHttp(){
200                if(HTTP_TRANSPORT.equals(thriftServerConfig.getTransportName())){
201                        try {
202                                // 反射获取私有的成员NettyServerTransport
203                                final NettyServerTransport nettyServerTransport = ReflectionUtils.valueOfField(thriftServer, "transport");
204                                // 反射获取私有的成员ChannelPipelineFactory
205                                Field pipelineFactory = NettyServerTransport.class.getDeclaredField("pipelineFactory");
206                                {
207                                        Field modifiersField = Field.class.getDeclaredField("modifiers");
208                                        modifiersField.setAccessible(true); //Field 的 modifiers 是私有的
209                                        modifiersField.setInt(pipelineFactory, pipelineFactory.getModifiers() & ~Modifier.FINAL);
210                                }
211                                pipelineFactory.setAccessible(true);
212                                final ChannelPipelineFactory channelPipelineFactory = (ChannelPipelineFactory) pipelineFactory.get(nettyServerTransport);
213                                final Netty3CorsConfig corsConfig = Netty3CorsConfigBuilder.forAnyOrigin()
214                                        .allowedRequestMethods(POST,GET,OPTIONS)
215                                        .allowedRequestHeaders("Origin","Content-Type","Accept","application","x-requested-with")
216                                        .build();
217                                ChannelPipelineFactory factoryWithCORS = new ChannelPipelineFactory(){
218
219                                        @Override
220                                        public ChannelPipeline getPipeline() throws Exception {
221                                                // 修改 ChannelPipeline,在frameCodec后(顺序)增加CORS handler,XHR编解码器
222                                                ChannelPipeline cp = channelPipelineFactory.getPipeline();
223//                                              cp.remove("idleTimeoutHandler");
224//                                              cp.remove("idleDisconnectHandler");
225//                                              final ThriftServerDef def = ReflectionUtils.valueOfField(nettyServerTransport, "def");
226//                                              final NettyServerConfig nettyServerConfig = ReflectionUtils.valueOfField(nettyServerTransport, "nettyServerConfig");
227//                              if (def.getClientIdleTimeout() != null) {
228//                                  // Add handlers to detect idle client connections and disconnect them
229//                                  cp.addBefore("authHandler","idleTimeoutHandler", 
230//                                              new IdleStateHandler(nettyServerConfig.getTimer(),
231//                                                                                        0,
232//                                                                                        0,
233//                                                                                        0,
234//                                                                                        TimeUnit.MILLISECONDS));
235//                                  cp.addBefore("authHandler","idleDisconnectHandler", new IdleDisconnectHandler());
236//                              }
237                                                cp.addAfter("frameCodec", "thriftServerXHRCodec", new ThriftServerXHRCodec());
238                                                cp.addAfter("frameCodec", "cors", new Netty3CorsHandler(corsConfig));
239                                                return cp;
240                                        }};
241                                // 修改nettyServerTransport的私有常量pipelineFactory
242                                pipelineFactory.set(nettyServerTransport, factoryWithCORS);
243                        } catch (Exception e) {
244                                Throwables.throwIfUnchecked(e);
245                                throw new RuntimeException(e);
246                        }
247                }
248        }
249        /** 
250         * 返回注释{@link ThriftService}定义的服务名称
251         * @see  {@link ThriftServiceMetadata#getThriftServiceAnnotation(Class)}
252         */
253        private static final String getServiceName(Class<?> serviceClass){
254                ThriftService thriftService = ThriftServiceMetadata.getThriftServiceAnnotation(
255                                checkNotNull(serviceClass,"serviceClass is null"));
256                return Strings.isNullOrEmpty(thriftService.value())
257                                ? serviceClass.getSimpleName()
258                                : thriftService.value();
259        }
260        /** @see #getServiceName(Class) */
261        private static final String getServiceName(Object serviceInstance){
262                return getServiceName(serviceInstance.getClass());
263        }
264        @Override
265        protected String serviceName() {
266                return this.serviceName;
267        }
268
269        @Override
270        protected final void startUp() throws Exception {
271                thriftServer.start();
272                logger.info("{} service is running(服务启动)",serviceName());
273        }
274        @Override
275        protected final void shutDown() {
276                logger.info(" {} service shutdown(服务关闭) ",      serviceName());
277                thriftServer.close();
278        }
279        /** log 输出{@code config}中的关键参数 */
280        public static final void logThriftServerConfig(ThriftServerConfig config){
281                logger.info("RPC Service Parameters(服务运行参数):");
282                logger.info("port: {}", config.getPort());
283                logger.info("connectionLimit: {}", config.getConnectionLimit());
284                logger.info("workerThreads: {}", config.getWorkerThreads());
285                logger.info("idleConnectionTimeout: {}", config.getIdleConnectionTimeout());
286        }
287}