1 /*
2  * Copyright (C) 2019, HuntLabs
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 module hunt.database.driver.mysql.impl.codec.InitCommandCodec;
18 
19 import hunt.database.driver.mysql.impl.codec.CapabilitiesFlag;
20 import hunt.database.driver.mysql.impl.codec.CommandCodec;
21 import hunt.database.driver.mysql.impl.codec.InitialHandshakePacket;
22 import hunt.database.driver.mysql.impl.codec.MySQLEncoder;
23 import hunt.database.driver.mysql.impl.codec.Packets;
24 
25 import hunt.database.driver.mysql.impl.MySQLCollation;
26 import hunt.database.driver.mysql.impl.util.BufferUtils;
27 import hunt.database.driver.mysql.impl.util.Native41Authenticator;
28 import hunt.database.base.impl.Connection;
29 import hunt.database.base.impl.command.CommandResponse;
30 import hunt.database.base.impl.command.InitCommand;
31 
32 import hunt.Exceptions;
33 import hunt.logging;
34 import hunt.net.buffer.ByteBuf;
35 import hunt.collection.Map;
36 import hunt.text.Charset;
37 
38 import std.algorithm;
39 import std.array;
40 import std.conv;
41 import std.string;
42 
43 /**
44  * 
45  */
46 class InitCommandCodec : CommandCodec!(DbConnection, InitCommand) {
47 
48     private enum int SCRAMBLE_LENGTH = 20;
49     private enum int AUTH_PLUGIN_DATA_PART1_LENGTH = 8;
50 
51     private enum int ST_CONNECTING = 0;
52     private enum int ST_AUTHENTICATING = 1;
53     private enum int ST_CONNECTED = 2;
54 
55     private int status = 0;
56 
57     this(InitCommand cmd) {
58         super(cmd);
59     }
60 
61     override
62     void decodePayload(ByteBuf payload, int payloadLength, int sequenceId) {
63         switch (status) {
64             case ST_CONNECTING:
65                 decodeInit0(encoder, cmd, payload);
66                 status = ST_AUTHENTICATING;
67                 break;
68             case ST_AUTHENTICATING:
69                 decodeInit1(cmd, payload);
70                 break;
71 
72             default:
73                 warningf("Can't handle status: %d", status);
74                 break;
75         }
76     }
77 
78     private void decodeInit0(MySQLEncoder encoder, InitCommand cmd, ByteBuf payload) {
79         short protocolVersion = payload.readUnsignedByte();
80 
81         string serverVersion = BufferUtils.readNullTerminatedString(payload, StandardCharsets.US_ASCII);
82         version(HUNT_DEBUG) {
83             infof("protocolVersion: %d, serverVersion: %s", protocolVersion, serverVersion);
84         }
85 
86         // we assume the server version follows ${major}.${minor}.${release} in https://dev.mysql.com/doc/refman/8.0/en/which-version.html
87         string[] versionNumbers = serverVersion.split(".");
88         int majorVersion = to!int(versionNumbers[0]);
89         int minorVersion = to!int(versionNumbers[1]);
90         // we should truncate the possible suffixes here
91         string releaseVersion = versionNumbers[2];
92         int releaseNumber;
93         int indexOfFirstSeparator = cast(int)releaseVersion.indexOf("-");
94         if (indexOfFirstSeparator != -1) {
95             // handle unstable release suffixes
96             string releaseNumberString = releaseVersion[0 .. indexOfFirstSeparator];
97             releaseNumber = to!int(releaseNumberString);
98         } else {
99             releaseNumber = to!int(versionNumbers[2]);
100         }
101         if (majorVersion == 5 && (minorVersion < 7 || (minorVersion == 7 && releaseNumber < 5))) {
102             // EOF_HEADER is enabled
103         } else {
104             encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_DEPRECATE_EOF;
105         }
106 
107         long connectionId = payload.readUnsignedIntLE();
108 
109         // read first part of scramble
110         byte[] scramble = new byte[SCRAMBLE_LENGTH];
111         payload.readBytes(scramble, 0, AUTH_PLUGIN_DATA_PART1_LENGTH);
112 
113         //filler
114         payload.readByte();
115 
116         // read lower 2 bytes of Capabilities flags
117         int serverCapabilitiesFlags = payload.readUnsignedShortLE();
118 
119         short characterSet = payload.readUnsignedByte();
120 
121         int statusFlags = payload.readUnsignedShortLE();
122 
123         // read upper 2 bytes of Capabilities flags
124         int capabilityFlagsUpper = payload.readUnsignedShortLE();
125         serverCapabilitiesFlags |= (capabilityFlagsUpper << 16);
126 
127         // length of the combined auth_plugin_data (scramble)
128         short lenOfAuthPluginData;
129         bool isClientPluginAuthSupported = (serverCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH) != 0;
130         if (isClientPluginAuthSupported) {
131             lenOfAuthPluginData = payload.readUnsignedByte();
132         } else {
133             payload.readerIndex(payload.readerIndex() + 1);
134             lenOfAuthPluginData = 0;
135         }
136 
137         // 10 bytes reserved
138         payload.readerIndex(payload.readerIndex() + 10);
139 
140         // Rest of the plugin provided data
141         payload.readBytes(scramble, AUTH_PLUGIN_DATA_PART1_LENGTH, 
142             max(SCRAMBLE_LENGTH - AUTH_PLUGIN_DATA_PART1_LENGTH, lenOfAuthPluginData - 9));
143         payload.readByte(); // reserved byte
144 
145         string authPluginName = null;
146         if (isClientPluginAuthSupported) {
147             authPluginName = BufferUtils.readNullTerminatedString(payload, StandardCharsets.UTF_8);
148         }
149 
150         //TODO we may not need an extra object here?(inline)
151         // InitialHandshakePacket initialHandshakePacket = new InitialHandshakePacket(serverVersion,
152         //     connectionId,
153         //     serverCapabilitiesFlags,
154         //     characterSet,
155         //     statusFlags,
156         //     scramble,
157         //     authPluginName
158         // );
159 
160         bool ssl = false;
161         if (ssl) {
162             //TODO ssl
163             implementationMissing(false);
164         } else {
165             if (cmd.database() !is null && !cmd.database().empty()) {
166                 encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_CONNECT_WITH_DB;
167             }
168             string authMethodName = authPluginName; // initialHandshakePacket.getAuthMethodName();
169             byte[] serverScramble = scramble; // initialHandshakePacket.getScramble();
170             Map!(string, string) properties = cmd.properties();
171             MySQLCollation collation = MySQLCollation.utf8_general_ci;
172             try {
173                 if(properties.containsKey("collation")) {
174                     collation = MySQLCollation.valueOfName(properties.get("collation"));
175                     properties.remove("collation");
176                 } else {
177                     version(HUNT_DEBUG) warning(properties.toString());
178                 }
179             } catch (IllegalArgumentException e) {
180                 version(HUNT_DEBUG) warning(e.msg);
181                 version(HUNT_DB_DEBUG) warning(e);
182             }
183             int collationId = collation.collationId();
184             encoder.charset = collation.mappedCharsetName();
185 
186             Map!(string, string) clientConnectionAttributes = properties;
187             if (clientConnectionAttributes !is null && !clientConnectionAttributes.isEmpty()) {
188                 encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_CONNECT_ATTRS;
189             }
190             encoder.clientCapabilitiesFlag &= serverCapabilitiesFlags; // initialHandshakePacket.getServerCapabilitiesFlags();
191             sendHandshakeResponseMessage(cmd.username(), cmd.password(), cmd.database(), 
192                     collationId, serverScramble, authMethodName, clientConnectionAttributes);
193         }
194     }
195 
196     private void decodeInit1(InitCommand cmd, ByteBuf payload) {
197         //TODO auth switch support
198         Packets header = cast(Packets)payload.getUnsignedByte(payload.readerIndex());
199         switch (header) {
200             case Packets.OK_PACKET_HEADER:
201                 status = ST_CONNECTED;
202                 if(completionHandler !is null) {
203                     completionHandler(succeededResponse!(DbConnection)(cmd.connection()));
204                 }
205                 break;
206             case Packets.ERROR_PACKET_HEADER:
207                 handleErrorPacketPayload(payload);
208                 break;
209             default:
210                 throw new UnsupportedOperationException();
211         }
212     }
213 
214     private void sendHandshakeResponseMessage(string username, string password, string database, 
215             int collationId, byte[] serverScramble, string authMethodName, 
216             Map!(string, string) clientConnectionAttributes) {
217 
218         ByteBuf packet = allocateBuffer();
219         // encode packet header
220         int packetStartIdx = packet.writerIndex();
221         packet.writeMediumLE(0); // will set payload length later by calculation
222         packet.writeByte(sequenceId);
223 
224         // encode packet payload
225         int clientCapabilitiesFlags = encoder.clientCapabilitiesFlag;
226         packet.writeIntLE(clientCapabilitiesFlags);
227         packet.writeIntLE(0xFFFFFF);
228         packet.writeByte(collationId);
229         byte[] filler = new byte[23];
230         packet.writeBytes(filler);
231         BufferUtils.writeNullTerminatedString(packet, username, StandardCharsets.UTF_8);
232         if (password is null || password.empty()) {
233             packet.writeByte(0);
234         } else {
235             //TODO support different auth methods here
236 
237             byte[] scrambledPassword = Native41Authenticator.encode(password, StandardCharsets.UTF_8, serverScramble);
238             if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) {
239                 BufferUtils.writeLengthEncodedInteger(packet, scrambledPassword.length);
240                 packet.writeBytes(scrambledPassword);
241             } else if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_SECURE_CONNECTION) != 0) {
242                 packet.writeByte(cast(int)scrambledPassword.length);
243                 packet.writeBytes(scrambledPassword);
244             } else {
245                 packet.writeByte(0);
246             }
247         }
248         if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_CONNECT_WITH_DB) != 0) {
249             BufferUtils.writeNullTerminatedString(packet, database, StandardCharsets.UTF_8);
250         }
251         if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH) != 0) {
252             BufferUtils.writeNullTerminatedString(packet, authMethodName, StandardCharsets.UTF_8);
253         }
254         if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_CONNECT_ATTRS) != 0) {
255             ByteBuf kv =  allocateBuffer();
256             foreach (string key, string value; clientConnectionAttributes) {
257                 BufferUtils.writeLengthEncodedString(kv, key, StandardCharsets.UTF_8);
258                 BufferUtils.writeLengthEncodedString(kv, value, StandardCharsets.UTF_8);
259             }
260             BufferUtils.writeLengthEncodedInteger(packet, kv.readableBytes());
261             packet.writeBytes(kv);
262         }
263 
264         // set payload length
265         int payloadLength = packet.writerIndex() - packetStartIdx - 4;
266         packet.setMediumLE(packetStartIdx, payloadLength);
267 
268         sendPacket(packet, payloadLength);
269     }
270 }