ws

WebSocket frame builder and decoder (no network transport)
git clone https://noulin.net/git/ws.git
Log | Files | Refs | LICENSE

ws.c (11719B)


      1 #include "ws.h"
      2 #include "libsheepyObject.h"
      3 // from libsheepy.h:
      4 // internal
      5 // u8, u16, u32, u64
      6 // randomUrandomOpen, randomWord, randomUrandomClose
      7 // bCatS, eqS
      8 // logE
      9 // range
     10 // MIN, MIN3
     11 //
     12 // from libsheepyObject.h:
     13 // findG, indexOfG
     14 // setG
     15 
     16 #include <arpa/inet.h>
     17 
     18 /* enable/disable logging */
     19 #undef pLog
     20 #define pLog(...)
     21 
     22 #define BASE64_ENCODED_SIZE(len) (len+2)/3*4+1
     23 
     24 const char *BASE64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     25 
     26 void _base64_encode_triple(u8 triple[3], char result[4]) {
     27     u32 tripleValue;
     28 
     29     tripleValue = triple[0];
     30     tripleValue *= 256;
     31     tripleValue += triple[1];
     32     tripleValue *= 256;
     33     tripleValue += triple[2];
     34 
     35     range(i,4) {
     36         result[3-i] = BASE64_CHARS[tripleValue%64];
     37         tripleValue /= 64;
     38     }
     39 }
     40 
     41 bool base64_encode(u8 *source, size_t sourcelen, char *target, size_t targetlen) {
     42 
     43     if ((sourcelen+2)/3*4 > targetlen-1)
     44        return false;
     45 
     46     while (sourcelen >= 3) {
     47         _base64_encode_triple(source, target);
     48         sourcelen -= 3;
     49         source += 3;
     50         target += 4;
     51     }
     52 
     53     if (sourcelen > 0) {
     54         u8 temp[3];
     55         memset(temp, 0, sizeof(temp));
     56         memcpy(temp, source, sourcelen);
     57         _base64_encode_triple(temp, target);
     58         target[3] = '=';
     59         if (sourcelen == 1)
     60             target[2] = '=';
     61 
     62         target += 4;
     63     }
     64 
     65     target[0] = 0;
     66 
     67     return true;
     68 }
     69 
     70 struct sha1 {
     71   u64 len;    /* processed message length */
     72   u32 h[5];   /* hash state */
     73   u8 buf[64]; /* message block buffer */
     74 };
     75 
     76 #define SHA1_DIGEST_LEN 20
     77 
     78 /* reset state */
     79 internal void sha1Init(struct sha1 *ctx);
     80 
     81 /* process message */
     82 internal void sha1Update(struct sha1 *ctx, const void *m, u64 len);
     83 
     84 /* get message digest
     85  * state is ruined after sum, keep a copy if multiple sum is needed
     86  * part of the message might be left in s, zero it if secrecy is needed
     87  */
     88 internal void sha1Final(struct sha1 *ctx, u8 md[SHA1_DIGEST_LEN]);
     89 
     90 internal u32 rol(u32 n, u8 k) { return (n << k) | (n >> (32-k)); }
     91 
     92 #define F0(b,c,d) (d ^ (b & (c ^ d)))
     93 #define F1(b,c,d) (b ^ c ^ d)
     94 #define F2(b,c,d) ((b & c) | (d & (b | c)))
     95 #define F3(b,c,d) (b ^ c ^ d)
     96 #define G0(a,b,c,d,e,i) e += rol(a,5)+F0(b,c,d)+W[i]+0x5A827999; b = rol(b,30)
     97 #define G1(a,b,c,d,e,i) e += rol(a,5)+F1(b,c,d)+W[i]+0x6ED9EBA1; b = rol(b,30)
     98 #define G2(a,b,c,d,e,i) e += rol(a,5)+F2(b,c,d)+W[i]+0x8F1BBCDC; b = rol(b,30)
     99 #define G3(a,b,c,d,e,i) e += rol(a,5)+F3(b,c,d)+W[i]+0xCA62C1D6; b = rol(b,30)
    100 
    101 internal void processblock(struct sha1 *s, const u8 *buf) {
    102   u32 W[80], a, b, c, d, e;
    103   u8 i;
    104 
    105   for (i = 0; i < 16; i++) {
    106     W[i] = (u32)buf[4*i]<<24;
    107     W[i] |= (u32)buf[4*i+1]<<16;
    108     W[i] |= (u32)buf[4*i+2]<<8;
    109     W[i] |= buf[4*i+3];
    110   }
    111   for (; i < 80; i++)
    112     W[i] = rol(W[i-3] ^ W[i-8] ^ W[i-14] ^ W[i-16], 1);
    113   a = s->h[0];
    114   b = s->h[1];
    115   c = s->h[2];
    116   d = s->h[3];
    117   e = s->h[4];
    118   for (i = 0; i < 20; ) {
    119     G0(a,b,c,d,e,i++);
    120     G0(e,a,b,c,d,i++);
    121     G0(d,e,a,b,c,i++);
    122     G0(c,d,e,a,b,i++);
    123     G0(b,c,d,e,a,i++);
    124   }
    125   while (i < 40) {
    126     G1(a,b,c,d,e,i++);
    127     G1(e,a,b,c,d,i++);
    128     G1(d,e,a,b,c,i++);
    129     G1(c,d,e,a,b,i++);
    130     G1(b,c,d,e,a,i++);
    131   }
    132   while (i < 60) {
    133     G2(a,b,c,d,e,i++);
    134     G2(e,a,b,c,d,i++);
    135     G2(d,e,a,b,c,i++);
    136     G2(c,d,e,a,b,i++);
    137     G2(b,c,d,e,a,i++);
    138   }
    139   while (i < 80) {
    140     G3(a,b,c,d,e,i++);
    141     G3(e,a,b,c,d,i++);
    142     G3(d,e,a,b,c,i++);
    143     G3(c,d,e,a,b,i++);
    144     G3(b,c,d,e,a,i++);
    145   }
    146   s->h[0] += a;
    147   s->h[1] += b;
    148   s->h[2] += c;
    149   s->h[3] += d;
    150   s->h[4] += e;
    151 }
    152 
    153 internal void pad(struct sha1 *s) {
    154   u8 r = s->len % 64;
    155 
    156   s->buf[r++] = 0x80;
    157   if (r > 56) {
    158     memset(s->buf + r, 0, 64 - r);
    159     r = 0;
    160     processblock(s, s->buf);
    161   }
    162   memset(s->buf + r, 0, 56 - r);
    163   s->len *= 8;
    164   s->buf[56] = s->len >> 56;
    165   s->buf[57] = s->len >> 48;
    166   s->buf[58] = s->len >> 40;
    167   s->buf[59] = s->len >> 32;
    168   s->buf[60] = s->len >> 24;
    169   s->buf[61] = s->len >> 16;
    170   s->buf[62] = s->len >> 8;
    171   s->buf[63] = s->len;
    172   processblock(s, s->buf);
    173 }
    174 
    175 internal void sha1Init(struct sha1 *s) {
    176   s->len = 0;
    177   s->h[0] = 0x67452301;
    178   s->h[1] = 0xEFCDAB89;
    179   s->h[2] = 0x98BADCFE;
    180   s->h[3] = 0x10325476;
    181   s->h[4] = 0xC3D2E1F0;
    182 }
    183 
    184 internal void sha1Final(struct sha1 *s,  u8 md[SHA1_DIGEST_LEN]) {
    185 
    186   pad(s);
    187   range(i,5) {
    188     md[4*i] = s->h[i] >> 24;
    189     md[4*i+1] = s->h[i] >> 16;
    190     md[4*i+2] = s->h[i] >> 8;
    191     md[4*i+3] = s->h[i];
    192   }
    193 }
    194 
    195 internal void sha1Update(struct sha1 *s,  const void *m, u64 len) {
    196   const u8 *p = m;
    197   u8 r = s->len % 64;
    198 
    199   s->len += len;
    200   if (r) {
    201     if (len < 64 - r) {
    202       memcpy(s->buf + r, p, len);
    203       return;
    204     }
    205     memcpy(s->buf + r, p, 64 - r);
    206     len -= 64 - r;
    207     p += 64 - r;
    208     processblock(s, s->buf);
    209   }
    210   for (; len >= 64; len -= 64, p += 64)
    211     processblock(s, p);
    212   memcpy(s->buf, p, len);
    213 }
    214 
    215 internal u64 keyRnd[2];
    216 internal char clientKey[BASE64_ENCODED_SIZE(sizeof(keyRnd))];
    217 
    218 internal const char handshake[] =
    219     "GET /connect HTTP/1.1\r\n"
    220     "Host: %s\r\n"
    221     "Upgrade: websocket\r\n"
    222     "Connection: Upgrade\r\n"
    223     "Sec-WebSocket-Key: %s\r\n"
    224     "Sec-WebSocket-Version: 13\r\n"
    225     "\r\n";
    226 
    227 internal const char serverResponse[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    228 
    229 size_t wsHandshakeSize(char *hostname) {
    230   return (sizeof(handshake) -1 -4) + BASE64_ENCODED_SIZE(sizeof(keyRnd))-1 + strlen(hostname) +1;
    231 }
    232 
    233 int wsHandshake(char *frame, size_t size, char *hostname) {
    234 
    235   randomUrandomOpen();
    236   keyRnd[0] = randomWord();
    237   keyRnd[1] = randomWord();
    238   randomUrandomClose();
    239 
    240   base64_encode((u8*)keyRnd, sizeof(keyRnd), clientKey, sizeof(clientKey));
    241 
    242   snprintf(frame, size, "GET /connect HTTP/1.1\r\n"
    243     "Host: %s\r\n"
    244     "Upgrade: websocket\r\n"
    245     "Connection: Upgrade\r\n"
    246     "Sec-WebSocket-Key: %s\r\n"
    247     "Sec-WebSocket-Version: 13\r\n"
    248     "\r\n", hostname, clientKey);
    249   return 0;
    250 }
    251 
    252 bool wsHanskakeCheck(char *frame, size_t size) {
    253   char *srvRes;
    254 
    255   if (srvRes = findG(frame, "sec-websocket-accept: ")) {
    256     srvRes += 22;
    257     int idx = indexOfG(srvRes, "\r\n");
    258 
    259     if ((srvRes - frame + idx) > size)
    260       return false;
    261 
    262     setG(srvRes, idx, 0);
    263 
    264     // createAcceptKey
    265     char s[128];
    266     bCatS(s, clientKey, serverResponse);
    267 
    268     struct sha1 shc;
    269     sha1Init(&shc);
    270     sha1Update(&shc, s, BASE64_ENCODED_SIZE(sizeof(keyRnd)) -1 + sizeof(serverResponse) -1);
    271     u8 digest[SHA1_DIGEST_LEN];
    272     sha1Final(&shc, digest);
    273 
    274     char base64Digest[800];
    275     base64_encode(digest, sizeof(digest), base64Digest, sizeof(base64Digest));
    276     if (!eqS(srvRes, base64Digest)) {
    277       logE("Bad server response!");
    278       return false;
    279     }
    280   }
    281   else {
    282     return false;
    283   }
    284   return true;
    285 }
    286 
    287 #define WS_FINAL_FRAME  1 << 7
    288 #define WS_OP_MASK  0xF
    289 #define WS_MASK 1 << 7
    290 #define WS_HEADER_SIZE 2
    291 #define WS_MASK_SIZE 4
    292 #define WS_LEN_MASK 0x7F
    293 
    294 #define WS_PAYLOAD_EXTEND_1 126
    295 #define WS_PAYLOAD_EXTEND_2 127
    296 
    297 #define WS_LEN_SIZE_1 2
    298 #define WS_LEN_SIZE_2 8
    299 
    300 // 1 if opcode is control frame opcode, otherwise 0
    301 #define isCtlFrame(opcode) ((opcode >> 3) & 1)
    302 
    303 ssize_t wsMaskSize(wsOpt op, size_t pSize) {
    304   if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1;
    305   if (pSize < WS_PAYLOAD_EXTEND_1)                    return pSize + WS_HEADER_SIZE + WS_MASK_SIZE;
    306   if (pSize < 1 << 16)                                return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_1;
    307   return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_2;
    308 }
    309 
    310 ssize_t wsNoMaskSize(wsOpt op, size_t pSize) {
    311   if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1;
    312   if (pSize < WS_PAYLOAD_EXTEND_1)                    return pSize + WS_HEADER_SIZE;
    313   if (pSize < 1 << 16)                                return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_1;
    314   return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_2;
    315 }
    316 
    317 ssize_t wsControlMaskSize(void) {
    318   return WS_HEADER_SIZE + WS_MASK_SIZE;
    319 }
    320 
    321 ssize_t wsControlNoMaskSize(void) {
    322   return WS_HEADER_SIZE;
    323 }
    324 
    325 bool wsMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) {
    326   #define WS_STEP1\
    327   ssize_t frameSize = wsMaskSize(op, pSize);\
    328   \
    329   if (frameSize == -1 || frameSize > size) return false;\
    330   \
    331   zeroBuf(frame, frameSize);\
    332   \
    333   if (final) frame[0] = WS_FINAL_FRAME | op;\
    334   else       frame[0] = op;
    335   WS_STEP1;
    336 
    337   u8 wsPLen, maskOffset;
    338   if (pSize < WS_PAYLOAD_EXTEND_1) {
    339     wsPLen     = pSize;
    340     maskOffset = WS_HEADER_SIZE;
    341   }
    342   else if (pSize < 1 << 16) {
    343     wsPLen     = WS_PAYLOAD_EXTEND_1;
    344     maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1;
    345     u16 *len   = (u16*)&frame[2];
    346     *len       = htons((u16)pSize);
    347   }
    348   else {
    349     wsPLen     = WS_PAYLOAD_EXTEND_2;
    350     maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2;
    351     u64 *len   = (u64*)&frame[2];
    352     *len       = htobe64(pSize);
    353 
    354   }
    355   frame[1] = WS_MASK | wsPLen;
    356   // mask key 4 bytes, random = always 0
    357 
    358   u8 payloadOffset = maskOffset + WS_MASK_SIZE;
    359 
    360   strncpy(frame+payloadOffset, payload, pSize);
    361 
    362   range(i, pSize) {
    363     frame[i + payloadOffset] ^= frame[maskOffset + (i % 4)];
    364   }
    365 
    366   return true;
    367 }
    368 
    369 bool wsNoMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) {
    370   WS_STEP1;
    371 
    372   u8 wsPLen, payloadOffset;
    373   if (pSize < WS_PAYLOAD_EXTEND_1) {
    374     wsPLen        = pSize;
    375     payloadOffset = WS_HEADER_SIZE;
    376   }
    377   else if (pSize < 1 << 16) {
    378     wsPLen        = WS_PAYLOAD_EXTEND_1;
    379     payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1;
    380     u16 *len   = (u16*)&frame[2];
    381     *len       = htons((u16)pSize);
    382   }
    383   else {
    384     wsPLen        = WS_PAYLOAD_EXTEND_2;
    385     payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2;
    386     u64 *len   = (u64*)&frame[2];
    387     *len       = htobe64(pSize);
    388 
    389   }
    390   frame[1] = wsPLen;
    391 
    392   strncpy(frame+payloadOffset, payload, pSize);
    393 
    394   return true;
    395 }
    396 
    397 #define genControlFrame(SIZE_FUNC) \
    398   ssize_t frameSize = SIZE_FUNC(op, 0);\
    399   zeroBuf(frame, frameSize);\
    400   frame[0] = WS_FINAL_FRAME | op
    401 
    402 bool wsControlMask(char *frame, wsOpt op) {
    403 
    404   if (!isCtlFrame(op)) return false;
    405   genControlFrame(wsMaskSize);
    406   frame[1] = WS_MASK;
    407 
    408   return true;
    409 }
    410 
    411 bool wsControlNoMask(char *frame, wsOpt op) {
    412 
    413   if (!isCtlFrame(op)) return false;
    414   genControlFrame(wsNoMaskSize);
    415 
    416   return true;
    417 }
    418 
    419 
    420 size_t wsDecodeSize(char *frame) {
    421   u8 len = frame[1] & WS_LEN_MASK;
    422   switch(len) {
    423     case WS_PAYLOAD_EXTEND_1:;
    424       u16 len1 = *(u16*)&frame[2];
    425       return ntohs(len1);
    426     case WS_PAYLOAD_EXTEND_2:;
    427       u64 len2 = *(u64*)&frame[2];
    428       return be64toh(len2);
    429     default:
    430       return len;
    431   }
    432 }
    433 
    434 void wsDecode(char *data, size_t size, char *frame, size_t fSize) {
    435 
    436   size_t sz     = MIN3(wsDecodeSize(frame), size, fSize);
    437   char *payload = wsDecodePayOffset(frame);
    438   strncpy(data, payload, sz);
    439 
    440   if (wsIsMasked(frame)) {
    441     char *mask = payload -4;
    442     range(i, sz) {
    443       data[i] ^= *(mask + (i % 4));
    444     }
    445   }
    446 }
    447 
    448 char *wsDecodePayOffset(char *frame) {
    449   char *r;
    450   r = frame + WS_HEADER_SIZE;
    451   u8 len = frame[1] & WS_LEN_MASK;
    452   if (len == WS_PAYLOAD_EXTEND_1) {
    453     r += WS_LEN_SIZE_1;
    454   }
    455   else if (len == WS_PAYLOAD_EXTEND_2) {
    456     r += WS_LEN_SIZE_2;
    457   }
    458   if (wsIsMasked(frame)) r += WS_MASK_SIZE;
    459   return r;
    460 }
    461 
    462 char *wsDecodeInPlace(char *frame, size_t size) {
    463 
    464   char *r = wsDecodePayOffset(frame);
    465   if (wsIsMasked(frame)) {
    466     char *mask = r -4;
    467     size_t sz  = MIN(wsDecodeSize(frame), size);
    468     range(i, sz) {
    469       *(r+i) ^= *(mask + (i % 4));
    470     }
    471   }
    472   return r;
    473 }
    474 
    475 bool wsIsFinal(char *frame) {
    476   return (frame[0] & WS_FINAL_FRAME) == WS_FINAL_FRAME;
    477 }
    478 
    479 bool wsIsMasked(char *frame) {
    480   return (frame[1] & WS_MASK) == WS_MASK;
    481 }
    482 
    483 wsOpt wsDecodeOp(char *frame) {
    484   return (wsOpt) frame[0] & WS_OP_MASK;
    485 }
    486 
    487 bool checkLibsheepyVersionWs(const char *currentLibsheepyVersion) {
    488   return eqG(currentLibsheepyVersion, LIBSHEEPY_VERSION);
    489 }
    490