1 /*
2  * Copyright (C) 2019, HuntLabs
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 module hunt.database.driver.mysql.impl.codec.CommandCodec;
18 
19 import hunt.database.driver.mysql.impl.codec.CapabilitiesFlag;
20 import hunt.database.driver.mysql.impl.codec.ColumnDefinition;
21 import hunt.database.driver.mysql.impl.codec.DataType;
22 import hunt.database.driver.mysql.impl.codec.MySQLEncoder;
23 // import hunt.database.driver.mysql.impl.codec.NoticeResponse;
24 import hunt.database.driver.mysql.impl.codec.Packets;
25 
26 import hunt.database.driver.mysql.MySQLException;
27 import hunt.database.driver.mysql.impl.util.BufferUtils;
28 
29 import hunt.database.base.Common;
30 import hunt.database.base.impl.command.CommandBase;
31 import hunt.database.base.impl.command.CommandResponse;
32 import hunt.database.base.impl.TxStatus;
33 
34 import hunt.net.buffer;
35 import hunt.logging;
36 import hunt.text.Charset;
37 
38 /**
39  * 
40  */
41 abstract class CommandCodecBase {
42 
43     int sequenceId;
44     MySQLEncoder encoder;
45     Throwable failure;
46 
47     // EventHandler!(NoticeResponse) noticeHandler;
48     EventHandler!(ICommandResponse) completionHandler;
49     
50     void encode(MySQLEncoder encoder);
51     void decodePayload(ByteBuf payload, int payloadLength, int sequenceId);
52 
53     ICommand getCommand();
54 }
55 
/**
 * Common implementation for the typed MySQL command codecs.
 *
 * Keeps the command being executed and offers helpers for writing
 * (possibly split) MySQL packets and for decoding the OK, ERR, EOF and
 * column-definition packet payloads shared by several commands.
 *
 * Params:
 *   R = result type of the command.
 *   C = command type; must convert to `CommandBase!R`.
 */
abstract class CommandCodec(R, C) : CommandCodecBase
        if(is(C : CommandBase!(R))) {

    R result;
    C cmd;
    private ByteBufAllocator alloc;

    /**
     * Creates a codec for `cmd`; buffers are allocated from
     * `UnpooledByteBufAllocator.DEFAULT()`.
     */
    this(C cmd) {
        this.cmd = cmd;
        alloc = UnpooledByteBufAllocator.DEFAULT();
    }

    /**
     * Stores `encoder`. The send/decode helpers below dereference `encoder`,
     * so subclasses overriding this need it set (e.g. via `super.encode`).
     */
    override void encode(MySQLEncoder encoder) {
        this.encoder = encoder;
    }

    /// Allocates an I/O buffer without an explicit initial capacity.
    ByteBuf allocateBuffer() {
        return alloc.ioBuffer();
    }

    /// Allocates an I/O buffer with the given initial capacity.
    ByteBuf allocateBuffer(int capacity) {
        return alloc.ioBuffer(capacity);
    }

    /**
     * Writes `packet` and flushes it, splitting it into several MySQL packets
     * when the payload reaches `Packets.PACKET_PAYLOAD_LENGTH_LIMIT`.
     *
     * Params:
     *   packet = buffer whose first 4 bytes are the packet header, followed
     *            by the payload; the header is replaced when splitting.
     *   payloadLength = number of payload bytes (header excluded).
     */
    void sendPacket(ByteBuf packet, int payloadLength) {
        if (payloadLength >= Packets.PACKET_PAYLOAD_LENGTH_LIMIT) {
            /*
                 The original packet exceeds the limit of packet length, split the packet here.
                 if payload length is exactly 16MBytes-1byte(0xFFFFFF), an empty packet is needed to indicate the termination.
             */
            sendSplitPacket(packet);
        } else {
            sendNonSplitPacket(packet);
        }
    }

    /**
     * Sends the payload of `packet` as consecutive packets carrying at most
     * `Packets.PACKET_PAYLOAD_LENGTH_LIMIT` bytes each.
     *
     * The original 4-byte header is skipped; each chunk gets a freshly
     * allocated header (3-byte little-endian length + next sequence id).
     * The final chunk may be empty, which terminates a payload whose length
     * is an exact multiple of the limit.
     */
    private void sendSplitPacket(ByteBuf packet) {
        ByteBuf payload = packet.skipBytes(4);
        while (payload.readableBytes() >= Packets.PACKET_PAYLOAD_LENGTH_LIMIT) {
            // send a packet with 0xFFFFFF length payload
            ByteBuf packetHeader = allocateBuffer(4);
            packetHeader.writeMediumLE(Packets.PACKET_PAYLOAD_LENGTH_LIMIT);
            packetHeader.writeByte(sequenceId++);
            encoder.write(packetHeader);
            encoder.write(payload.readRetainedSlice(Packets.PACKET_PAYLOAD_LENGTH_LIMIT));
        }

        // send a packet with the last part of the payload (possibly empty)
        ByteBuf packetHeader = allocateBuffer(4);
        packetHeader.writeMediumLE(payload.readableBytes());
        packetHeader.writeByte(sequenceId++);
        encoder.write(packetHeader);
        encoder.writeAndFlush(payload);
    }

    /**
     * Writes `packet` unchanged, flushes it and advances `sequenceId` by one.
     * NOTE(review): the caller is presumably expected to have written the
     * header (with the current sequence id) into `packet` — confirm.
     */
    void sendNonSplitPacket(ByteBuf packet) {
        sequenceId++;
        encoder.writeAndFlush(packet);
    }

    /**
     * Completes the command according to the packet header, which is peeked
     * without moving the reader index.
     *
     * OK and EOF packets complete it successfully with a null result, ERR
     * packets go to `handleErrorPacketPayload`, anything else is logged.
     */
    void handleOkPacketOrErrorPacketPayload(ByteBuf payload) {
        Packets header = cast(Packets)payload.getUnsignedByte(payload.readerIndex());
        switch (header) {
            case Packets.EOF_PACKET_HEADER:
            case Packets.OK_PACKET_HEADER:
                if(completionHandler !is null) {
                    completionHandler(succeededResponse(cast(Object)null));
                }
                break;

            case Packets.ERROR_PACKET_HEADER:
                handleErrorPacketPayload(payload);
                break;
            
            default:
                // `warningf` applies the format string; plain `warning` only
                // concatenates its arguments and would print "%d" verbatim.
                warningf("Can't handle Packets: %d", header);
                break;
        }
    }

    /**
     * Decodes an ERR packet payload and completes the command with a failed
     * response carrying a `MySQLException` and `TxStatus.FAILED`.
     *
     * Layout read: header byte, 2-byte LE error code, then (only with
     * CLIENT_PROTOCOL_41) a '#' marker and a 5-byte SQL state, and finally
     * the error message filling the rest of the payload.
     */
    void handleErrorPacketPayload(ByteBuf payload) {
        payload.skipBytes(1); // skip ERR packet header
        int errorCode = payload.readUnsignedShortLE();
        string sqlState = null;
        if ((encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_PROTOCOL_41) != 0) {
            payload.skipBytes(1); // SQL state marker will always be #
            sqlState = BufferUtils.readFixedLengthString(payload, 5, StandardCharsets.UTF_8);
        }
        string errorMessage = readRestOfPacketString(payload, StandardCharsets.UTF_8);

        if(completionHandler !is null) {
            completionHandler(failedResponse!R(
                    new MySQLException(errorMessage, errorCode, sqlState), TxStatus.FAILED));
        }
    }

    /**
     * Decodes an OK packet payload.
     *
     * Reads affected rows and last insert id (length-encoded), then status
     * flags/warning count depending on CLIENT_PROTOCOL_41 or
     * CLIENT_TRANSACTIONS, then the optional status info (and session state
     * info when CLIENT_SESSION_TRACK is set and the server reports a
     * session-state change).
     *
     * Params:
     *   payload = payload positioned at the OK header byte.
     *   charset = charset used to decode the info strings.
     */
    OkPacket decodeOkPacketPayload(ByteBuf payload, Charset charset) {
        payload.skipBytes(1); // skip OK packet header
        long affectedRows = BufferUtils.readLengthEncodedInteger(payload);
        long lastInsertId = BufferUtils.readLengthEncodedInteger(payload);
        int serverStatusFlags = 0;
        int numberOfWarnings = 0;
        if ((encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_PROTOCOL_41) != 0) {
            serverStatusFlags = payload.readUnsignedShortLE();
            numberOfWarnings = payload.readUnsignedShortLE();
        } else if ((encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_TRANSACTIONS) != 0) {
            serverStatusFlags = payload.readUnsignedShortLE();
        }
        string statusInfo;
        string sessionStateInfo = null;
        if (payload.readableBytes() == 0) {
            // handle when OK packet does not contain server status info
            statusInfo = null;
        } else if ((encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_SESSION_TRACK) != 0) {
            statusInfo = BufferUtils.readLengthEncodedString(payload, charset);
            if ((serverStatusFlags & ServerStatusFlags.SERVER_SESSION_STATE_CHANGED) != 0) {
                sessionStateInfo = BufferUtils.readLengthEncodedString(payload, charset);
            }
        } else {
            statusInfo = readRestOfPacketString(payload, charset);
        }
        return new OkPacket(affectedRows, lastInsertId, serverStatusFlags, numberOfWarnings, statusInfo, sessionStateInfo);
    }

    /**
     * Decodes an EOF packet payload: header byte, 2-byte LE warning count,
     * 2-byte LE server status flags (5 bytes in total).
     */
    EofPacket decodeEofPacketPayload(ByteBuf payload) {
        payload.skipBytes(1); // skip EOF_Packet header
        int numberOfWarnings = payload.readUnsignedShortLE();
        int serverStatusFlags = payload.readUnsignedShortLE();
        return new EofPacket(numberOfWarnings, serverStatusFlags);
    }

    /// Reads all remaining readable bytes of `payload` as a string in `charset`.
    string readRestOfPacketString(ByteBuf payload, Charset charset) {
        return BufferUtils.readFixedLengthString(payload, payload.readableBytes(), charset);
    }

    /**
     * Decodes a column definition packet payload into a `ColumnDefinition`.
     *
     * Reads six length-encoded UTF-8 strings (catalog, schema, table, original
     * table, name, original name), a length-encoded integer that is skipped,
     * then character set, column length, type, flags and decimals. Any bytes
     * after `decimals` are left unread.
     */
    ColumnDefinition decodeColumnDefinitionPacketPayload(ByteBuf payload) {
        string catalog = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        string schema = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        string table = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        string orgTable = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        string name = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        string orgName = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
        // Length of the fixed-length fields that follow: read only to advance
        // the reader index, the value itself is not needed.
        BufferUtils.readLengthEncodedInteger(payload);
        int characterSet = payload.readUnsignedShortLE();
        long columnLength = payload.readUnsignedIntLE();
        DataType type = cast(DataType)payload.readUnsignedByte();
        int flags = payload.readUnsignedShortLE();
        byte decimals = payload.readByte();
        return new ColumnDefinition(catalog, schema, table, orgTable, name, orgName, 
            characterSet, columnLength, type, flags, decimals);
    }

    /**
     * Skips a 5-byte EOF packet payload (header, warnings, status flags) unless
     * the client negotiated CLIENT_DEPRECATE_EOF.
     */
    void skipEofPacketIfNeeded(ByteBuf payload) {
        if (!isDeprecatingEofFlagEnabled()) {
            payload.skipBytes(5);
        }
    }

    /// Returns true when the client capabilities include CLIENT_DEPRECATE_EOF.
    bool isDeprecatingEofFlagEnabled() {
        return (encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_DEPRECATE_EOF) != 0;
    }
    
    /// Returns the command handled by this codec.
    override C getCommand() {
        return cmd;
    }
}