001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.transport.amqp; 018 019import java.io.DataInput; 020import java.io.DataInputStream; 021import java.io.DataOutput; 022import java.io.DataOutputStream; 023import java.io.IOException; 024import java.io.OutputStream; 025import java.nio.ByteBuffer; 026import java.nio.channels.Channels; 027import java.nio.channels.WritableByteChannel; 028 029import org.apache.activemq.transport.amqp.message.InboundTransformer; 030import org.apache.activemq.util.ByteArrayInputStream; 031import org.apache.activemq.util.ByteArrayOutputStream; 032import org.apache.activemq.util.ByteSequence; 033import org.apache.activemq.wireformat.WireFormat; 034import org.fusesource.hawtbuf.Buffer; 035import org.slf4j.Logger; 036import org.slf4j.LoggerFactory; 037 038public class AmqpWireFormat implements WireFormat { 039 040 private static final Logger LOG = LoggerFactory.getLogger(AmqpWireFormat.class); 041 042 public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE; 043 public static final int NO_AMQP_MAX_FRAME_SIZE = -1; 044 public static final int DEFAULT_CONNECTION_TIMEOUT = 30000; 045 public static final int DEFAULT_IDLE_TIMEOUT = 30000; 046 public static final int DEFAULT_PRODUCER_CREDIT = 1000; 047 public static final boolean DEFAULT_ALLOW_NON_SASL_CONNECTIONS = false; 048 public static final int DEFAULT_ANQP_FRAME_SIZE = 128 * 1024; 049 050 private static final int SASL_PROTOCOL = 3; 051 052 private int version = 1; 053 private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE; 054 private int maxAmqpFrameSize = DEFAULT_ANQP_FRAME_SIZE; 055 private int connectAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT; 056 private int idelTimeout = DEFAULT_IDLE_TIMEOUT; 057 private int producerCredit = DEFAULT_PRODUCER_CREDIT; 058 private String transformer = InboundTransformer.TRANSFORMER_JMS; 059 private boolean allowNonSaslConnections = DEFAULT_ALLOW_NON_SASL_CONNECTIONS; 060 061 private boolean magicRead = false; 062 private ResetListener resetListener; 063 064 public interface ResetListener { 065 void onProtocolReset(); 066 } 067 068 @Override 069 public ByteSequence marshal(Object command) throws IOException { 070 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 071 DataOutputStream dos = new DataOutputStream(baos); 072 marshal(command, dos); 073 dos.close(); 074 return baos.toByteSequence(); 075 } 076 077 @Override 078 public Object unmarshal(ByteSequence packet) throws IOException { 079 ByteArrayInputStream stream = new ByteArrayInputStream(packet); 080 DataInputStream dis = new DataInputStream(stream); 081 return unmarshal(dis); 082 } 083 084 @Override 085 public void marshal(Object command, DataOutput dataOut) throws IOException { 086 if (command instanceof ByteBuffer) { 087 ByteBuffer buffer = (ByteBuffer) command; 088 089 if (dataOut instanceof OutputStream) { 090 WritableByteChannel channel = Channels.newChannel((OutputStream) dataOut); 091 channel.write(buffer); 092 } else { 093 while (buffer.hasRemaining()) { 094 dataOut.writeByte(buffer.get()); 095 } 096 } 097 } else { 098 Buffer frame = (Buffer) command; 099 frame.writeTo(dataOut); 100 } 101 } 102 103 @Override 104 public Object unmarshal(DataInput dataIn) throws IOException { 105 if (!magicRead) { 106 Buffer magic = new Buffer(8); 107 magic.readFrom(dataIn); 108 magicRead = true; 109 return new AmqpHeader(magic, false); 110 } else { 111 int size = dataIn.readInt(); 112 if (size > maxFrameSize) { 113 throw new AmqpProtocolException("Frame size exceeded max frame length."); 114 } else if (size <= 0) { 115 throw new AmqpProtocolException("Frame size value was invalid: " + size); 116 } 117 Buffer frame = new Buffer(size); 118 frame.bigEndianEditor().writeInt(size); 119 frame.readFrom(dataIn); 120 frame.clear(); 121 return frame; 122 } 123 } 124 125 /** 126 * Given an AMQP header validate that the AMQP magic is present and 127 * if so that the version and protocol values align with what we support. 128 * 129 * In the case where authentication occurs the client sends us two AMQP 130 * headers, the first being the SASL initial header which triggers the 131 * authentication process and then if that succeeds we should get a second 132 * AMQP header that does not contain the SASL protocol ID indicating the 133 * connection process should follow the normal path. We validate that the 134 * header align with these expectations. 135 * 136 * @param header 137 * the header instance received from the client. 138 * @param authenticated 139 * has the client already authenticated already. 140 * 141 * @return true if the header is valid against the current WireFormat. 142 */ 143 public boolean isHeaderValid(AmqpHeader header, boolean authenticated) { 144 if (!header.hasValidPrefix()) { 145 LOG.trace("AMQP Header arrived with invalid prefix: {}", header); 146 return false; 147 } 148 149 if (!(header.getProtocolId() == 0 || header.getProtocolId() == SASL_PROTOCOL)) { 150 LOG.trace("AMQP Header arrived with invalid protocol ID: {}", header); 151 return false; 152 } 153 154 if (!authenticated && !isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) { 155 LOG.trace("AMQP Header arrived without SASL and server requires SASL: {}", header); 156 return false; 157 } 158 159 if (header.getMajor() != 1 || header.getMinor() != 0 || header.getRevision() != 0) { 160 LOG.trace("AMQP Header arrived invalid version: {}", header); 161 return false; 162 } 163 164 return true; 165 } 166 167 /** 168 * Returns an AMQP Header object that represents the minimally protocol 169 * versions supported by this transport. A client that attempts to 170 * connect with an AMQP version that doesn't at least meat this value 171 * will receive this prior to the connection being closed. 172 * 173 * @return the minimal AMQP version needed from the client. 174 */ 175 public AmqpHeader getMinimallySupportedHeader() { 176 AmqpHeader header = new AmqpHeader(); 177 if (!isAllowNonSaslConnections()) { 178 header.setProtocolId(3); 179 } 180 181 return header; 182 } 183 184 @Override 185 public void setVersion(int version) { 186 this.version = version; 187 } 188 189 @Override 190 public int getVersion() { 191 return this.version; 192 } 193 194 public void resetMagicRead() { 195 this.magicRead = false; 196 if (resetListener != null) { 197 resetListener.onProtocolReset(); 198 } 199 } 200 201 public void setProtocolResetListener(ResetListener listener) { 202 this.resetListener = listener; 203 } 204 205 public boolean isMagicRead() { 206 return this.magicRead; 207 } 208 209 public long getMaxFrameSize() { 210 return maxFrameSize; 211 } 212 213 public void setMaxFrameSize(long maxFrameSize) { 214 this.maxFrameSize = maxFrameSize; 215 } 216 217 public int getMaxAmqpFrameSize() { 218 return maxAmqpFrameSize; 219 } 220 221 public void setMaxAmqpFrameSize(int maxAmqpFrameSize) { 222 this.maxAmqpFrameSize = maxAmqpFrameSize; 223 } 224 225 public boolean isAllowNonSaslConnections() { 226 return allowNonSaslConnections; 227 } 228 229 public void setAllowNonSaslConnections(boolean allowNonSaslConnections) { 230 this.allowNonSaslConnections = allowNonSaslConnections; 231 } 232 233 public int getConnectAttemptTimeout() { 234 return connectAttemptTimeout; 235 } 236 237 public void setConnectAttemptTimeout(int connectAttemptTimeout) { 238 this.connectAttemptTimeout = connectAttemptTimeout; 239 } 240 241 public void setProducerCredit(int producerCredit) { 242 this.producerCredit = producerCredit; 243 } 244 245 public int getProducerCredit() { 246 return producerCredit; 247 } 248 249 public String getTransformer() { 250 return transformer; 251 } 252 253 public void setTransformer(String transformer) { 254 this.transformer = transformer; 255 } 256 257 public int getIdleTimeout() { 258 return idelTimeout; 259 } 260 261 public void setIdleTimeout(int idelTimeout) { 262 this.idelTimeout = idelTimeout; 263 } 264}