#include "dbserver.h"

#include <limits.h>

/*
 * Returns the index of the n-th occurrence (1-based) of character c in s,
 * or -1 if s contains fewer than n occurrences (or n is not positive).
 */
int get_nth_char_position(const char *s, int c, int n)
{
    int charCount = 0;

    for (int i = 0; s[i] != '\0'; i++) {
        if (s[i] == c && (++charCount) == n) {
            return i;
        }
    }
    return -1;
}

/*
 * Parses rawNumber as a non-negative decimal integer.
 *
 * Returns the parsed value, or -1 if the string is NULL, empty, contains
 * any non-digit character, or does not fit in an int.
 */
int get_valid_number(char *rawNumber)
{
    if (rawNumber == NULL || rawNumber[0] == '\0') {
        return -1;
    }

    int value = 0;
    for (const char *p = rawNumber; *p != '\0'; p++) {
        // <ctype.h> functions require an unsigned char value (or EOF)
        if (!isdigit((unsigned char)*p)) {
            return -1;
        }
        int digit = *p - '0';
        // Reject instead of overflowing (signed overflow is undefined)
        if (value > (INT_MAX - digit) / 10) {
            return -1;
        }
        value = value * 10 + digit;
    }
    return value;
}

/*
 * Reads the first line of the file at filePath via read_line().
 *
 * Returns whatever read_line() produced (the caller free()s it in main),
 * or NULL if the file cannot be opened or no line could be read.
 */
char *get_authkey_from_file(char *filePath)
{
    FILE *fin = fopen(filePath, "r");
    if (fin == NULL) {
        return NULL;
    }

    char *line = read_line(fin);
    // Close on every path so a failed read does not leak the stream
    fclose(fin);
    return line;
}

/*
 * Resolves a passive (listening) IPv4 TCP address for the given port.
 *
 * Returns a list the caller must release with freeaddrinfo(), or NULL on
 * failure.
 */
struct addrinfo *get_address_info(int port)
{
    struct addrinfo *ai = NULL;
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    // Convert port to string for getaddrinfo; sized for any int value
    char portnum[16];
    int written = snprintf(portnum, sizeof(portnum), "%d", port);
    if (written < 0 || (size_t)written >= sizeof(portnum)) {
        return NULL;
    }

    // On failure no list was allocated, so there is nothing to free
    if (getaddrinfo(NULL, portnum, &hints, &ai) != 0) {
        return NULL;
    }
    return ai;
}

/*
 * Creates a listening IPv4 TCP socket on the given port and prints the
 * port number actually bound to stderr.
 *
 * Returns the listening descriptor, or -1 on any failure (in which case
 * no descriptor is left open).
 */
int open_socket(int port)
{
    struct addrinfo *ai = get_address_info(port);
    if (ai == NULL) {
        return -1;
    }

    int recvFd = socket(AF_INET, SOCK_STREAM, 0);
    if (recvFd < 0) {
        freeaddrinfo(ai);
        return -1;
    }

    // Reuse ports immediately
    int optionValue = 1;
    if (setsockopt(recvFd, SOL_SOCKET, SO_REUSEADDR, &optionValue,
                sizeof(optionValue)) < 0
            || bind(recvFd, ai->ai_addr, ai->ai_addrlen) < 0) {
        freeaddrinfo(ai);
        close(recvFd);
        return -1;
    }
    freeaddrinfo(ai);

    // Backlogged if not being used right away
    if (listen(recvFd, MAX_BACKLOG) < 0) {
        close(recvFd);
        return -1;
    }

    // Get the bound port number and print it to stderr
    struct sockaddr_in sin;
    socklen_t len = sizeof(sin);
    if (getsockname(recvFd, (struct sockaddr *)&sin, &len) == -1) {
        close(recvFd);
        return -1;
    }
    fprintf(stderr, "%d\n", ntohs(sin.sin_port));

    return recvFd;
}

/*
 * Allocates the global serverStats structure with all counters at zero
 * and initialises its mutex. Terminates the process if allocation fails,
 * since every other function dereferences serverStats unconditionally.
 */
void serverstats_init()
{
    serverStats = calloc(1, sizeof(*serverStats));
    if (serverStats == NULL) {
        perror("calloc");
        exit(EXIT_FAILURE);
    }
    serverStats->currentConnections = 0;
    serverStats->completedConnections = 0;
    serverStats->authFails = 0;
    serverStats->gets = 0;
    serverStats->puts = 0;
    serverStats->deletes = 0;
    // Dynamically allocated mutexes are initialised with pthread_mutex_init
    pthread_mutex_init(&(serverStats->connectionLock), NULL);
}

/*
 * Opens the listening socket and accepts clients forever, handing each
 * accepted connection to a detached thread running handle_connection().
 *
 * Returns EXIT_PORT_ERR if the socket cannot be opened, or 1 if accept()
 * fails. A failure affecting a single client (name lookup, allocation,
 * thread creation) drops only that client.
 */
int handle_listen(int port, int connections, char *authKey)
{
    // Create a public (0) and private (1) stringstore
    StringStore *stringStores[2] = {NULL, NULL};
    stringStores[0] = stringstore_init();
    stringStores[1] = stringstore_init();

    int recvFd = open_socket(port);
    if (recvFd < 0) {
        return EXIT_PORT_ERR;
    }

    while (1) {
        // accept() overwrites the length, so reset it for every call
        struct sockaddr_in fromAddress = {0};
        socklen_t fromAddressSize = sizeof(fromAddress);

        int fd = accept(
                recvFd, (struct sockaddr *)&fromAddress, &fromAddressSize);
        if (fd < 0) {
            break;
        }

        char hostname[NI_MAXHOST];
        if (getnameinfo((struct sockaddr *)&fromAddress, fromAddressSize,
                    hostname, NI_MAXHOST, NULL, 0, 0) != 0) {
            close(fd);
            continue;
        }

        // Ownership of ciPtr passes to the new thread, which frees it
        struct ConnectionInfo *ciPtr = calloc(1, sizeof(*ciPtr));
        if (ciPtr == NULL) {
            close(fd);
            continue;
        }
        ciPtr->fd = fd;
        ciPtr->connections = connections;
        ciPtr->authKey = authKey;
        ciPtr->stringStores = stringStores;

        pthread_t threadId;
        if (pthread_create(&threadId, NULL, handle_connection, ciPtr) != 0) {
            free(ciPtr);
            close(fd);
            continue;
        }
        pthread_detach(threadId);
    }

    // The string stores are deliberately not released here: detached
    // client threads may still be using them.
    close(recvFd);
    return 1;
}

/*
 * Builds the response to a GET for key: 200 with the stored value, or
 * 404 if stringstore_retrieve() finds nothing. Successful lookups bump
 * the GET counter.
 */
char *process_response_get(char *key, StringStore *stringStore)
{
    char *value = (char *)stringstore_retrieve(stringStore, key);
    if (value == NULL) {
        return construct_HTTP_response(404, "Not Found", NULL, NULL);
    }

    char *response = construct_HTTP_response(200, "OK", NULL, value);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->gets += 1;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return response;
}

/*
 * Builds the response to a PUT of key/value: 200 on success, or 500 when
 * stringstore_add() returns 0. Successful stores bump the PUT counter.
 */
char *process_response_put(char *key, char *value, StringStore *stringStore)
{
    if (stringstore_add(stringStore, key, value) == 0) {
        return construct_HTTP_response(
                500, "Internal Server Error", NULL, NULL);
    }

    char *response = construct_HTTP_response(200, "OK", NULL, NULL);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->puts += 1;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return response;
}

/*
 * Builds the response to a DELETE of key: 200 on success, or 404 when
 * stringstore_delete() returns 0. Successful deletes bump the DELETE
 * counter.
 */
char *process_response_delete(char *key, StringStore *stringStore)
{
    if (stringstore_delete(stringStore, key) == 0) {
        return construct_HTTP_response(404, "Not Found", NULL, NULL);
    }

    char *response = construct_HTTP_response(200, "OK", NULL, NULL);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->deletes += 1;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return response;
}

/*
 * Returns 1 if the NULL-terminated headers array contains an
 * "Authorization" header whose value equals authKey, otherwise 0.
 */
int check_auth_key(HttpHeader **headers, char *authKey)
{
    if (headers == NULL || authKey == NULL) {
        return 0;
    }

    for (int i = 0; headers[i] != NULL; i++) {
        if (strcmp(headers[i]->name, "Authorization") == 0
                && strcmp(headers[i]->value, authKey) == 0) {
            return 1;
        }
    }
    return 0;
}

/*
 * Routes a parsed request to the public (index 0) or private (index 1)
 * string store and returns the response text.
 *
 * Returns NULL when the request is malformed (missing method, path or
 * key, unknown path, unsupported method); the caller turns that into a
 * 400. A private-path request without a matching auth key yields a 401
 * and bumps the auth-failure counter.
 */
char *construct_response(
        struct RequestParams req, StringStore **stringStores)
{
    if (req.method == NULL || req.path == NULL || req.key == NULL) {
        return NULL;
    }

    // Default to public string store
    StringStore *stringStore = stringStores[0];

    if (strcmp(req.path, PRIVATE_PATH) == 0) {
        // The client is trying to access a private path
        if (!check_auth_key(req.headers, req.authKey)) {
            pthread_mutex_lock(&(serverStats->connectionLock));
            serverStats->authFails += 1;
            pthread_mutex_unlock(&(serverStats->connectionLock));
            return construct_HTTP_response(401, "Unauthorized", NULL, NULL);
        }
        stringStore = stringStores[1];
    } else if (strcmp(req.path, PUBLIC_PATH) != 0) {
        return NULL;
    }

    if (strcmp(req.method, "GET") == 0) {
        return process_response_get(req.key, stringStore);
    }
    if (strcmp(req.method, "PUT") == 0) {
        return process_response_put(req.key, req.value, stringStore);
    }
    if (strcmp(req.method, "DELETE") == 0) {
        return process_response_delete(req.key, stringStore);
    }
    return NULL;
}

/*
 * Splits address at its second '/' into a path (everything up to and
 * including that slash) and a key (everything after it).
 *
 * On success stores two newly allocated, NUL-terminated strings in *path
 * and *key (caller frees) and returns 0. Returns 1 if address is NULL,
 * has fewer than two slashes, or allocation fails; in the allocation
 * failure case both outputs are set to NULL.
 */
int get_path_key(char *address, char **path, char **key)
{
    if (address == NULL) {
        return 1;
    }

    int keyStartPos = get_nth_char_position(address, '/', 2) + 1;
    if (keyStartPos <= 0) {
        return 1;
    }

    size_t pathLen = (size_t)keyStartPos;
    size_t keyLen = strlen(address) - pathLen;

    char *pathCopy = malloc(pathLen + 1);
    char *keyCopy = malloc(keyLen + 1);
    if (pathCopy == NULL || keyCopy == NULL) {
        free(pathCopy);
        free(keyCopy);
        *path = NULL;
        *key = NULL;
        return 1;
    }

    memcpy(pathCopy, address, pathLen);
    // Explicit terminator: the path is later compared with strcmp()
    pathCopy[pathLen] = '\0';
    memcpy(keyCopy, address + pathLen, keyLen + 1);

    *path = pathCopy;
    *key = keyCopy;
    return 0;
}

/*
 * Returns a RequestParams value with every member zeroed / NULL.
 */
struct RequestParams request_params_init()
{
    struct RequestParams requestParams = {0};
    return requestParams;
}

/*
 * Reads one HTTP request from recv, writes the response to send and
 * releases everything allocated for the request.
 *
 * Returns 0 if a request was handled (the caller should keep reading),
 * or 1 if get_HTTP_request() did not return 1 or no response could be
 * built.
 */
int handle_connection_request(
        FILE *recv, FILE *send, char *authKey, StringStore **stringStores)
{
    struct RequestParams req = request_params_init();
    char *address = NULL;

    if (get_HTTP_request(recv, &(req.method), &address, &(req.headers),
                &(req.value)) != 1) {
        return 1;
    }

    // On failure path/key stay NULL and construct_response() returns
    // NULL, which is reported to the client as a 400 below.
    get_path_key(address, &(req.path), &(req.key));
    req.authKey = authKey;

    char *response = construct_response(req, stringStores);
    if (response == NULL) {
        response = construct_HTTP_response(400, "Bad Request", NULL, NULL);
    }

    int status = 1;
    if (response != NULL) {
        // The response contains client-supplied data, so it must never
        // be interpreted as a printf format string.
        fputs(response, send);
        fflush(send);
        status = 0;
    }

    free(req.method);
    free(address);
    free(req.value);
    free(response);
    free(req.path);
    free(req.key);
    free_array_of_headers(req.headers);
    return status;
}

/*
 * Serves requests on the connected socket sendFd until
 * handle_connection_request() reports EOF or an invalid request.
 *
 * This function takes ownership of sendFd: the descriptor is closed
 * before it returns, on every path.
 */
void handle_client_connection(
        int sendFd, char *authKey, StringStore **stringStores)
{
    // Separate descriptors so the read and write streams close cleanly
    int recvFd = dup(sendFd);
    if (recvFd < 0) {
        close(sendFd);
        return;
    }

    FILE *send = fdopen(sendFd, "w");
    FILE *recv = fdopen(recvFd, "r");
    if (send != NULL && recv != NULL) {
        // Loop until EOF or invalid request
        while (handle_connection_request(recv, send, authKey, stringStores)
                == 0) {
            ;
        }
    }

    // fclose() also closes the underlying descriptor; fall back to
    // close() only for a descriptor that never got a stream.
    if (recv != NULL) {
        fclose(recv);
    } else {
        close(recvFd);
    }
    if (send != NULL) {
        fclose(send);
    } else {
        close(sendFd);
    }
}

/*
 * Thread entry point for one accepted client. ciPtr is a heap-allocated
 * struct ConnectionInfo which this function copies and frees.
 *
 * Serves the client if the connection limit (0 meaning unlimited) is not
 * exceeded; otherwise writes UNAVALIABLE_RESPONSE and closes the socket.
 * Always returns NULL.
 */
void *handle_connection(void *ciPtr)
{
    struct ConnectionInfo ci = *(struct ConnectionInfo *)ciPtr;
    free(ciPtr);

    // Update the count and take the admission decision under one lock so
    // the value tested is the value this thread produced.
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->currentConnections += 1;
    int withinLimit = (ci.connections == 0
            || serverStats->currentConnections <= ci.connections);
    pthread_mutex_unlock(&(serverStats->connectionLock));

    if (withinLimit) {
        // handle_client_connection() closes ci.fd itself
        handle_client_connection(ci.fd, ci.authKey, ci.stringStores);
    } else {
        // Send them away; the rejection notice is best effort
        if (write(ci.fd, UNAVALIABLE_RESPONSE, strlen(UNAVALIABLE_RESPONSE))
                < 0) {
            ;
        }
        close(ci.fd);
    }

    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->currentConnections--;
    if (withinLimit) {
        serverStats->completedConnections++;
    }
    pthread_mutex_unlock(&(serverStats->connectionLock));

    return NULL;
}

/*
 * Signal handler (installed for SIGHUP) that prints the server counters
 * to stderr. The signal number s is unused.
 *
 * NOTE(review): fprintf() is not async-signal-safe and the counters are
 * read without holding connectionLock; a dedicated sigwait() thread
 * would be the robust alternative.
 */
void print_server_stats(int s)
{
    (void)s;
    fprintf(stderr,
            "Connected clients:%d\nCompleted clients:%d\nAuth "
            "failures:%d\nGET "
            "operations:%d\nPUT operations:%d\nDELETE operations:%d\n",
            serverStats->currentConnections,
            serverStats->completedConnections, serverStats->authFails,
            serverStats->gets, serverStats->puts, serverStats->deletes);
}

/*
 * Installs print_server_stats() as the SIGHUP handler, with SA_RESTART
 * so interrupted system calls are restarted.
 */
void setup_signal_handler(void)
{
    struct sigaction sa;
    memset(&sa, 0, sizeof(sa));
    sa.sa_handler = print_server_stats;
    sa.sa_flags = SA_RESTART;
    sigaction(SIGHUP, &sa, 0);
}

/*
 * Entry point. Arguments: authfile connections [portnum].
 *
 * Exits with EXIT_ERR on a usage error, EXIT_FILE_ERR if the auth string
 * cannot be read, or the value returned by handle_listen() if listening
 * fails.
 */
int main(int argc, char *argv[])
{
    serverstats_init();
    setup_signal_handler();

    if (argc < 3 || argc > 4) {
        fprintf(stderr, PROGRAM_USAGE);
        exit(EXIT_ERR);
    }

    int connections = get_valid_number(argv[2]);
    if (connections < 0) {
        fprintf(stderr, PROGRAM_USAGE);
        exit(EXIT_ERR);
    }

    // argv[argc] is NULL, so argv[3] is safe to test when argc == 3
    int portnum = 0;
    if (argv[3] && strlen(argv[3]) > 0) {
        portnum = get_valid_number(argv[3]);
        if (portnum != 0 && (portnum < PORT_MIN || portnum > PORT_MAX)) {
            fprintf(stderr, PROGRAM_USAGE);
            exit(EXIT_ERR);
        }
    }

    char *authKey = get_authkey_from_file(argv[1]);
    if (authKey == NULL) {
        fprintf(stderr, "dbserver: unable to read authentication string\n");
        exit(EXIT_FILE_ERR);
    }

    int err = handle_listen(portnum, connections, authKey);
    if (err > 0) {
        fprintf(stderr, "dbserver: unable to open socket for listening\n");
        free(authKey);
        exit(err);
    }

    free(authKey);
    free(serverStats);
    return 0;
}