/*
 * Copyright (C) 2019, HuntLabs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
module hunt.database.driver.mysql.impl.codec.InitCommandCodec;

import hunt.database.driver.mysql.impl.codec.CapabilitiesFlag;
import hunt.database.driver.mysql.impl.codec.CommandCodec;
import hunt.database.driver.mysql.impl.codec.InitialHandshakePacket;
import hunt.database.driver.mysql.impl.codec.MySQLEncoder;
import hunt.database.driver.mysql.impl.codec.Packets;

import hunt.database.driver.mysql.impl.MySQLCollation;
import hunt.database.driver.mysql.impl.util.BufferUtils;
import hunt.database.driver.mysql.impl.util.Native41Authenticator;
import hunt.database.base.impl.Connection;
import hunt.database.base.impl.command.CommandResponse;
import hunt.database.base.impl.command.InitCommand;

import hunt.Exceptions;
import hunt.logging;
import hunt.net.buffer.ByteBuf;
import hunt.collection.Map;
import hunt.text.Charset;

import std.algorithm;
import std.array;
import std.conv;
import std.string;

/**
 * Codec driving the MySQL connection phase for an `InitCommand`.
 *
 * The first packet received is the server's initial handshake. It is parsed and
 * answered with a handshake response carrying the user name, the scrambled password,
 * the optional database, the collation and any connection attributes. The next
 * packet is either an OK packet, which completes the command with the connection,
 * or an ERR packet, which is passed to `handleErrorPacketPayload`.
 */
class InitCommandCodec : CommandCodec!(DbConnection, InitCommand) {

    /// Number of scramble bytes kept for password scrambling.
    private enum int SCRAMBLE_LENGTH = 20;
    /// Number of scramble bytes carried in auth-plugin-data-part-1.
    private enum int AUTH_PLUGIN_DATA_PART1_LENGTH = 8;

    // Codec states: waiting for the initial handshake, waiting for the
    // authentication result, and done.
    private enum int ST_CONNECTING = 0;
    private enum int ST_AUTHENTICATING = 1;
    private enum int ST_CONNECTED = 2;

    private int status = ST_CONNECTING;

    this(InitCommand cmd) {
        super(cmd);
    }

    /**
     * Dispatches an incoming packet according to the current codec state.
     *
     * The first packet is decoded as the initial handshake, after which the codec moves
     * to `ST_AUTHENTICATING`. Later packets are decoded as the authentication result.
     * In any other state the packet is ignored and a warning is logged.
     */
    override
    void decodePayload(ByteBuf payload, int payloadLength, int sequenceId) {
        switch (status) {
            case ST_CONNECTING:
                decodeInit0(encoder, cmd, payload);
                status = ST_AUTHENTICATING;
                break;
            case ST_AUTHENTICATING:
                decodeInit1(cmd, payload);
                break;

            default:
                warningf("Can't handle status: %d", status);
                break;
        }
    }

    /**
     * Parses the leading decimal digits of one server-version component.
     *
     * Any trailing suffix is ignored, e.g. "25-log", "6-MariaDB" or "40_debug".
     *
     * Returns: the parsed number, or 0 when the component does not start with a digit.
     */
    private static int leadingNumber(string component) {
        size_t end = 0;
        while (end < component.length && component[end] >= '0' && component[end] <= '9') {
            end++;
        }
        return end == 0 ? 0 : to!int(component[0 .. end]);
    }

    /**
     * Decodes the server's initial handshake packet and replies with the handshake response.
     *
     * Side effects on `encoder`:
     * - adjusts `encoder.clientCapabilitiesFlag` (DEPRECATE_EOF, CONNECT_WITH_DB,
     *   CONNECT_ATTRS, then masked with the server's capabilities);
     * - sets `encoder.charset` from the chosen collation.
     *
     * Side effect on the command: removes the "collation" key from `cmd.properties()`.
     *
     * Params:
     *   encoder = the connection's encoder, whose capability flags and charset are updated
     *   cmd     = the init command supplying credentials, database and properties
     *   payload = the handshake packet payload, consumed by this method
     */
    private void decodeInit0(MySQLEncoder encoder, InitCommand cmd, ByteBuf payload) {
        short protocolVersion = payload.readUnsignedByte();

        string serverVersion = BufferUtils.readNullTerminatedString(payload, StandardCharsets.US_ASCII);
        version(HUNT_DEBUG) {
            infof("protocolVersion: %d, serverVersion: %s", protocolVersion, serverVersion);
        }

        // We assume the server version follows ${major}.${minor}.${release}
        // (https://dev.mysql.com/doc/refman/8.0/en/which-version.html), where each component
        // may carry a vendor suffix. Only the leading digits are parsed, and a missing
        // component counts as 0, so an unusual version string cannot abort the handshake.
        string[] versionNumbers = serverVersion.split(".");
        int majorVersion = versionNumbers.length > 0 ? leadingNumber(versionNumbers[0]) : 0;
        int minorVersion = versionNumbers.length > 1 ? leadingNumber(versionNumbers[1]) : 0;
        int releaseNumber = versionNumbers.length > 2 ? leadingNumber(versionNumbers[2]) : 0;

        // For 5.x servers before 5.7.5, CLIENT_DEPRECATE_EOF is not requested, so EOF
        // packets stay enabled. Every other version gets the flag.
        bool eofHeaderEnabled = majorVersion == 5
                && (minorVersion < 7 || (minorVersion == 7 && releaseNumber < 5));
        if (!eofHeaderEnabled) {
            encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_DEPRECATE_EOF;
        }

        long connectionId = payload.readUnsignedIntLE();

        // first part of the scramble
        byte[] scramble = new byte[SCRAMBLE_LENGTH];
        payload.readBytes(scramble, 0, AUTH_PLUGIN_DATA_PART1_LENGTH);

        // filler
        payload.readByte();

        // lower 2 bytes of the capability flags
        int serverCapabilitiesFlags = payload.readUnsignedShortLE();

        short characterSet = payload.readUnsignedByte();

        int statusFlags = payload.readUnsignedShortLE();

        // upper 2 bytes of the capability flags
        int capabilityFlagsUpper = payload.readUnsignedShortLE();
        serverCapabilitiesFlags |= (capabilityFlagsUpper << 16);

        // Length of the combined auth_plugin_data (scramble). It is only read when the
        // server advertises CLIENT_PLUGIN_AUTH; otherwise the byte is skipped.
        short lenOfAuthPluginData;
        bool isClientPluginAuthSupported = (serverCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH) != 0;
        if (isClientPluginAuthSupported) {
            lenOfAuthPluginData = payload.readUnsignedByte();
        } else {
            payload.readerIndex(payload.readerIndex() + 1);
            lenOfAuthPluginData = 0;
        }

        // 10 bytes reserved
        payload.readerIndex(payload.readerIndex() + 10);

        // Rest of the plugin provided data: max(13, lenOfAuthPluginData - 8) bytes in total,
        // consumed as the remaining scramble bytes followed by one trailing byte.
        // Only SCRAMBLE_LENGTH bytes fit into `scramble`. If the server announces more, the
        // surplus is skipped instead of overrunning the array.
        enum int part2ScrambleLength = SCRAMBLE_LENGTH - AUTH_PLUGIN_DATA_PART1_LENGTH;
        int part2Length = max(part2ScrambleLength, lenOfAuthPluginData - 9);
        payload.readBytes(scramble, AUTH_PLUGIN_DATA_PART1_LENGTH, part2ScrambleLength);
        payload.readerIndex(payload.readerIndex() + (part2Length - part2ScrambleLength));
        payload.readByte(); // trailing byte of auth-plugin-data-part-2

        string authPluginName = null;
        if (isClientPluginAuthSupported) {
            authPluginName = BufferUtils.readNullTerminatedString(payload, StandardCharsets.UTF_8);
        }

        // TODO: SSL is not implemented yet; `ssl` is a placeholder that is always false.
        bool ssl = false;
        if (ssl) {
            implementationMissing(false);
        } else {
            if (cmd.database() !is null && !cmd.database().empty()) {
                encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_CONNECT_WITH_DB;
            }

            Map!(string, string) properties = cmd.properties();
            MySQLCollation collation = MySQLCollation.utf8_general_ci;
            if (properties !is null && properties.containsKey("collation")) {
                // "collation" configures the driver itself. It is removed before the
                // remaining properties are sent as connection attributes, even when its
                // value turns out not to be a known collation.
                string collationName = properties.get("collation");
                properties.remove("collation");
                try {
                    collation = MySQLCollation.valueOfName(collationName);
                } catch (IllegalArgumentException e) {
                    // Unknown collation name: keep the utf8_general_ci default.
                    version(HUNT_DEBUG) warning(e.msg);
                    version(HUNT_DB_DEBUG) warning(e);
                }
            } else {
                version(HUNT_DEBUG) {
                    if (properties !is null) {
                        warning(properties.toString());
                    }
                }
            }
            int collationId = collation.collationId();
            encoder.charset = collation.mappedCharsetName();

            // Whatever remains in the properties is sent to the server as connection attributes.
            Map!(string, string) clientConnectionAttributes = properties;
            if (clientConnectionAttributes !is null && !clientConnectionAttributes.isEmpty()) {
                encoder.clientCapabilitiesFlag |= CapabilitiesFlag.CLIENT_CONNECT_ATTRS;
            }
            // Only keep the client capabilities that the server also advertises.
            encoder.clientCapabilitiesFlag &= serverCapabilitiesFlags;
            sendHandshakeResponseMessage(cmd.username(), cmd.password(), cmd.database(),
                    collationId, scramble, authPluginName, clientConnectionAttributes);
        }
    }

    /**
     * Decodes the server's answer to the handshake response.
     *
     * - OK packet: marks the codec connected and completes the command with `cmd.connection()`.
     * - ERR packet: handled by `handleErrorPacketPayload`.
     *
     * Throws: UnsupportedOperationException for any other header. This includes
     *   authentication-switch requests, which are not supported yet.
     */
    private void decodeInit1(InitCommand cmd, ByteBuf payload) {
        // TODO: auth switch support
        // Peek at the header byte without consuming it.
        Packets header = cast(Packets)payload.getUnsignedByte(payload.readerIndex());
        switch (header) {
            case Packets.OK_PACKET_HEADER:
                status = ST_CONNECTED;
                if (completionHandler !is null) {
                    completionHandler(succeededResponse!(DbConnection)(cmd.connection()));
                }
                break;
            case Packets.ERROR_PACKET_HEADER:
                handleErrorPacketPayload(payload);
                break;
            default:
                throw new UnsupportedOperationException();
        }
    }

    /**
     * Encodes and sends the handshake response packet.
     *
     * Params:
     *   username                   = user name, written NUL-terminated
     *   password                   = plain password; null or empty sends an empty auth response
     *   database                   = initial schema, written only with CLIENT_CONNECT_WITH_DB
     *   collationId                = collation id written as the character-set byte
     *   serverScramble             = scramble received in the initial handshake
     *   authMethodName             = plugin name, written only with CLIENT_PLUGIN_AUTH
     *   clientConnectionAttributes = key/value pairs, written only with CLIENT_CONNECT_ATTRS
     */
    private void sendHandshakeResponseMessage(string username, string password, string database,
            int collationId, byte[] serverScramble, string authMethodName,
            Map!(string, string) clientConnectionAttributes) {

        ByteBuf packet = allocateBuffer();
        // Packet header: 3-byte payload length (patched at the end) + 1-byte sequence id.
        int packetStartIdx = packet.writerIndex();
        packet.writeMediumLE(0);
        packet.writeByte(sequenceId);

        // packet payload
        int clientCapabilitiesFlags = encoder.clientCapabilitiesFlag;
        packet.writeIntLE(clientCapabilitiesFlags);
        packet.writeIntLE(0xFFFFFF); // max packet size
        packet.writeByte(collationId);
        byte[] filler = new byte[23];
        packet.writeBytes(filler);
        BufferUtils.writeNullTerminatedString(packet, username, StandardCharsets.UTF_8);
        if (password is null || password.empty()) {
            packet.writeByte(0);
        } else {
            // TODO: support different auth methods here.
            // NOTE(review): the password is always scrambled with Native41Authenticator,
            // regardless of authMethodName.
            byte[] scrambledPassword = Native41Authenticator.encode(password, StandardCharsets.UTF_8, serverScramble);
            if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) {
                BufferUtils.writeLengthEncodedInteger(packet, scrambledPassword.length);
                packet.writeBytes(scrambledPassword);
            } else if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_SECURE_CONNECTION) != 0) {
                packet.writeByte(cast(int)scrambledPassword.length);
                packet.writeBytes(scrambledPassword);
            } else {
                packet.writeByte(0);
            }
        }
        if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_CONNECT_WITH_DB) != 0) {
            BufferUtils.writeNullTerminatedString(packet, database, StandardCharsets.UTF_8);
        }
        if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_PLUGIN_AUTH) != 0) {
            BufferUtils.writeNullTerminatedString(packet, authMethodName, StandardCharsets.UTF_8);
        }
        if ((clientCapabilitiesFlags & CapabilitiesFlag.CLIENT_CONNECT_ATTRS) != 0) {
            // Attributes are encoded into a scratch buffer first because their total size
            // is written as a length-encoded prefix.
            ByteBuf kv = allocateBuffer();
            foreach (string key, string value; clientConnectionAttributes) {
                BufferUtils.writeLengthEncodedString(kv, key, StandardCharsets.UTF_8);
                BufferUtils.writeLengthEncodedString(kv, value, StandardCharsets.UTF_8);
            }
            BufferUtils.writeLengthEncodedInteger(packet, kv.readableBytes());
            packet.writeBytes(kv);
        }

        // Payload length = everything written after the 4-byte header.
        int payloadLength = packet.writerIndex() - packetStartIdx - 4;
        packet.setMediumLE(packetStartIdx, payloadLength);

        sendPacket(packet, payloadLength);
    }
}