Files
csse2310-a4/dbserver.c
2025-03-21 10:59:34 +10:00

444 lines
12 KiB
C

#include "dbserver.h"

#include <errno.h>
#include <limits.h>
#include <strings.h>
/*
 * Returns the index of the n-th (1-based) occurrence of character c in s,
 * or -1 if s is NULL, n < 1, or s has fewer than n occurrences of c.
 */
int get_nth_char_position(const char *s, int c, int n) {
    if (s == NULL || n <= 0) {
        return -1;
    }
    int seen = 0;
    // Walk to the terminator once instead of calling strlen every iteration.
    for (int i = 0; s[i] != '\0'; i++) {
        if (s[i] == c && ++seen == n) {
            return i;
        }
    }
    return -1;
}
/*
 * Parses rawNumber as a non-negative decimal integer.
 * Returns the value, or -1 if rawNumber is NULL/empty, contains any
 * non-digit character (including a sign), or does not fit in an int.
 */
int get_valid_number(char *rawNumber) {
    if (rawNumber == NULL || rawNumber[0] == '\0') {
        return -1;
    }
    for (const char *p = rawNumber; *p != '\0'; p++) {
        // <ctype.h> functions require an unsigned char value (or EOF).
        if (!isdigit((unsigned char)*p)) {
            return -1;
        }
    }
    // strtol reports overflow via ERANGE; atoi would invoke UB instead.
    errno = 0;
    long value = strtol(rawNumber, NULL, 10);
    if (errno == ERANGE || value > INT_MAX) {
        return -1;
    }
    return (int)value;
}
/*
 * Opens filePath and returns the first line as read by read_line(), or NULL
 * if the file cannot be opened or read_line() returns NULL.
 * The file is closed on every path.
 */
char *get_authkey_from_file(char *filePath) {
    FILE *fin = fopen(filePath, "r");
    if (fin == NULL) {
        return NULL;
    }
    char *line = read_line(fin);
    // Close unconditionally; previously the stream leaked when read_line failed.
    fclose(fin);
    return line;
}
/*
 * Resolves a passive IPv4 TCP address for the given port (0 = any port).
 * Returns the getaddrinfo() result list (caller frees with freeaddrinfo),
 * or NULL on failure.
 */
struct addrinfo *get_address_info(int port) {
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;
    // Large enough for any int including sign; snprintf can never overflow it.
    char portnum[12];
    snprintf(portnum, sizeof(portnum), "%d", port);
    struct addrinfo *ai = NULL;
    if (getaddrinfo(NULL, portnum, &hints, &ai) != 0) {
        // On failure the result is not set, so there is nothing to free.
        return NULL;
    }
    return ai;
}
/*
 * Creates an IPv4 TCP socket bound to port (0 = ephemeral), starts
 * listening with a backlog of MAX_BACKLOG, and prints the actual bound port
 * number followed by a newline to stderr.
 * Returns the listening fd, or -1 on any failure (no fd is leaked).
 */
int open_socket(int port) {
    struct addrinfo *ai = get_address_info(port);
    if (ai == NULL) {
        return -1;
    }
    int recvFd = socket(AF_INET, SOCK_STREAM, 0);
    if (recvFd < 0) {
        freeaddrinfo(ai);
        return -1;
    }
    // Allow the port to be reused immediately after a previous server exits.
    int optionValue = 1;
    int bound = setsockopt(recvFd, SOL_SOCKET, SO_REUSEADDR, &optionValue,
                        sizeof(optionValue)) == 0
            && bind(recvFd, ai->ai_addr, ai->ai_addrlen) == 0;
    freeaddrinfo(ai);
    if (!bound || listen(recvFd, MAX_BACKLOG) < 0) {
        close(recvFd);
        return -1;
    }
    // Report the port actually chosen (relevant when port 0 was requested).
    struct sockaddr_in sin;
    socklen_t len = sizeof(sin);
    if (getsockname(recvFd, (struct sockaddr *)&sin, &len) < 0) {
        close(recvFd);
        return -1;
    }
    fprintf(stderr, "%d\n", ntohs(sin.sin_port));
    return recvFd;
}
/*
 * Allocates the global serverStats with every counter zeroed and its
 * connectionLock mutex initialised. Exits the process if allocation or
 * mutex initialisation fails, since the server cannot run without it.
 */
void serverstats_init(void) {
    // calloc zeroes all counters in one step.
    serverStats = calloc(1, sizeof(*serverStats));
    // PTHREAD_MUTEX_INITIALIZER is only valid for static initialisation;
    // a heap-allocated mutex must be set up with pthread_mutex_init.
    if (serverStats == NULL
            || pthread_mutex_init(&(serverStats->connectionLock), NULL) != 0) {
        fprintf(stderr, "dbserver: unable to initialise statistics\n");
        exit(EXIT_FAILURE);
    }
}
/*
 * Opens the listening socket on port and then accepts clients forever,
 * handing each to a detached handle_connection thread together with the
 * connection limit, the auth key and the public/private string stores.
 * Returns EXIT_PORT_ERR if the socket cannot be opened, or 1 if accept()
 * fails with a non-transient error. It does not otherwise return.
 */
int handle_listen(int port, int connections, char *authKey) {
    int recvFd = open_socket(port);
    if (recvFd < 0) {
        return EXIT_PORT_ERR;
    }
    // Public (0) and private (1) stores. Static storage so the array that
    // detached threads point at can never dangle; the stores are shared for
    // the life of the process and are therefore never freed here.
    static StringStore *stringStores[2] = {NULL, NULL};
    stringStores[0] = stringstore_init();
    stringStores[1] = stringstore_init();
    for (;;) {
        // The peer address was never used, so it is not requested.
        int fd = accept(recvFd, NULL, NULL);
        if (fd < 0) {
            // Interrupted or client-aborted accepts are not fatal to the server.
            if (errno == EINTR || errno == ECONNABORTED) {
                continue;
            }
            close(recvFd);
            return 1;
        }
        struct ConnectionInfo *ciPtr = calloc(1, sizeof(*ciPtr));
        if (ciPtr == NULL) {
            close(fd);
            continue;
        }
        ciPtr->fd = fd;
        ciPtr->connections = connections;
        ciPtr->authKey = authKey;
        ciPtr->stringStores = stringStores;
        // On success the thread takes ownership of ciPtr and fd.
        pthread_t threadId;
        if (pthread_create(&threadId, NULL, handle_connection, ciPtr) != 0) {
            free(ciPtr);
            close(fd);
            continue;
        }
        pthread_detach(threadId);
    }
}
/*
 * Looks key up in stringStore and builds the HTTP reply:
 * 404 "Not Found" when absent, otherwise 200 "OK" with the stored value as
 * the body. Only successful lookups increment serverStats->gets.
 */
char *process_response_get(char *key, StringStore *stringStore) {
    char *stored = (char *)stringstore_retrieve(stringStore, key);
    if (stored == NULL) {
        return construct_HTTP_response(404, "Not Found", NULL, NULL);
    }
    char *reply = construct_HTTP_response(200, "OK", NULL, stored);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->gets++;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return reply;
}
/*
 * Stores value under key in stringStore and builds the HTTP reply:
 * 500 "Internal Server Error" when stringstore_add reports failure (returns
 * 0), otherwise 200 "OK". Only successful stores increment serverStats->puts.
 */
char *process_response_put(char *key, char *value, StringStore *stringStore) {
    if (!stringstore_add(stringStore, key, value)) {
        return construct_HTTP_response(
                500, "Internal Server Error", NULL, NULL);
    }
    char *reply = construct_HTTP_response(200, "OK", NULL, NULL);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->puts++;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return reply;
}
/*
 * Removes key from stringStore and builds the HTTP reply:
 * 404 "Not Found" when stringstore_delete returns 0, otherwise 200 "OK".
 * Only successful deletions increment serverStats->deletes.
 */
char *process_response_delete(char *key, StringStore *stringStore) {
    if (!stringstore_delete(stringStore, key)) {
        return construct_HTTP_response(404, "Not Found", NULL, NULL);
    }
    char *reply = construct_HTTP_response(200, "OK", NULL, NULL);
    pthread_mutex_lock(&(serverStats->connectionLock));
    serverStats->deletes++;
    pthread_mutex_unlock(&(serverStats->connectionLock));
    return reply;
}
/*
 * Returns 1 if the NULL-terminated header array contains an "Authorization"
 * header whose value exactly equals authKey, otherwise 0.
 * The header name is matched case-insensitively, as HTTP field names are
 * case-insensitive; the value comparison stays exact.
 * A NULL array, NULL authKey, or headers with NULL name/value never match.
 */
int check_auth_key(HttpHeader **headers, char *authKey) {
    if (headers == NULL || authKey == NULL) {
        return 0;
    }
    for (HttpHeader **h = headers; *h != NULL; h++) {
        if ((*h)->name != NULL && (*h)->value != NULL
                && strcasecmp((*h)->name, "Authorization") == 0
                && strcmp((*h)->value, authKey) == 0) {
            return 1;
        }
    }
    return 0;
}
/*
 * Builds the HTTP response for a parsed request.
 * Returns NULL (the caller answers 400) when path/key are missing, the path
 * is neither PRIVATE_PATH nor PUBLIC_PATH, or the method is not GET/PUT/
 * DELETE. A PRIVATE_PATH request without a matching Authorization header
 * increments serverStats->authFails and gets 401 "Unauthorized".
 * stringStores[0] is the public store and stringStores[1] the private one.
 */
char *construct_response(
        struct RequestParams req, StringStore **stringStores) {
    if (req.path == NULL || req.key == NULL) {
        return NULL;
    }
    StringStore *target = NULL;
    if (strcmp(req.path, PRIVATE_PATH) == 0) {
        if (!check_auth_key(req.headers, req.authKey)) {
            pthread_mutex_lock(&(serverStats->connectionLock));
            serverStats->authFails++;
            pthread_mutex_unlock(&(serverStats->connectionLock));
            return construct_HTTP_response(401, "Unauthorized", NULL, NULL);
        }
        target = stringStores[1];
    } else if (strcmp(req.path, PUBLIC_PATH) == 0) {
        target = stringStores[0];
    } else {
        return NULL;
    }
    if (strcmp(req.method, "GET") == 0) {
        return process_response_get(req.key, target);
    }
    if (strcmp(req.method, "PUT") == 0) {
        return process_response_put(req.key, req.value, target);
    }
    if (strcmp(req.method, "DELETE") == 0) {
        return process_response_delete(req.key, target);
    }
    return NULL;
}
/*
 * Splits a request address at its second '/' into a path (everything up to
 * and including that slash, e.g. "/public/") and a key (everything after it,
 * possibly empty).
 * On success stores two newly malloc'd, NUL-terminated strings in *path and
 * *key (caller frees both) and returns 0. Returns 1 and leaves *path/*key
 * untouched if address is NULL, has fewer than two '/', or allocation fails.
 */
int get_path_key(char *address, char **path, char **key) {
    if (address == NULL) {
        return 1;
    }
    const char *firstSlash = strchr(address, '/');
    const char *secondSlash =
            (firstSlash != NULL) ? strchr(firstSlash + 1, '/') : NULL;
    if (secondSlash == NULL) {
        return 1;
    }
    size_t pathLen = (size_t)(secondSlash - address) + 1;
    const char *keyStart = secondSlash + 1;
    size_t keyLen = strlen(keyStart);
    char *newPath = malloc(pathLen + 1);
    char *newKey = malloc(keyLen + 1);
    if (newPath == NULL || newKey == NULL) {
        free(newPath);
        free(newKey);
        return 1;
    }
    // Terminate explicitly: the old strncpy copy of a prefix left *path
    // without a NUL, so the later strcmp read past the buffer.
    memcpy(newPath, address, pathLen);
    newPath[pathLen] = '\0';
    memcpy(newKey, keyStart, keyLen + 1);
    *path = newPath;
    *key = newKey;
    return 0;
}
struct RequestParams request_params_init() {
struct RequestParams requestParams = {0};
memset(&requestParams, 0, sizeof(struct RequestParams));
requestParams.method = NULL;
requestParams.path = NULL;
requestParams.key = NULL;
requestParams.value = NULL;
requestParams.authKey = NULL;
requestParams.headers = NULL;
memset(&(requestParams.headers), 0, sizeof(HttpHeader *));
return requestParams;
}
/*
 * Reads one HTTP request from recv, builds the response (400 "Bad Request"
 * when construct_response returns NULL), writes it to send and frees every
 * per-request allocation.
 * Returns 0 if a request was read and its response fully written and
 * flushed. Returns 1 if get_HTTP_request did not return 1 or the response
 * could not be produced/written, so the caller's loop stops.
 */
int handle_connection_request(
        FILE *recv, FILE *send, char *authKey, StringStore **stringStores) {
    struct RequestParams req = request_params_init();
    char *address = NULL;
    if (get_HTTP_request(recv, &(req.method), &address, &(req.headers),
                &(req.value)) != 1) {
        return 1;
    }
    // A failed split leaves path/key NULL, which construct_response rejects
    // by returning NULL (answered with 400 below).
    get_path_key(address, &(req.path), &(req.key));
    req.authKey = authKey;
    char *response = construct_response(req, stringStores);
    if (response == NULL) {
        response = construct_HTTP_response(400, "Bad Request", NULL, NULL);
    }
    // Never pass the response as a format string: its body can echo
    // client-supplied data (e.g. a stored value containing "%s").
    int failed = 1;
    if (response != NULL && fputs(response, send) != EOF
            && fflush(send) != EOF) {
        failed = 0;
    }
    free(req.method);
    free(address);
    free(req.value);
    free(response);
    free(req.path);
    free(req.key);
    free_array_of_headers(req.headers);
    return failed;
}
/*
 * Serves requests on the connected socket sendFd until
 * handle_connection_request returns non-zero (EOF, bad request or write
 * failure).
 * Both stdio streams wrap dup()'d descriptors, so closing them never closes
 * sendFd itself: the caller keeps ownership of sendFd and must close it.
 * (The original closed sendFd via fclose and then recvFd a second time.
 * Those double closes could close descriptors just opened by other threads.)
 */
void handle_client_connection(
        int sendFd, char *authKey, StringStore **stringStores) {
    int readFd = dup(sendFd);
    if (readFd < 0) {
        return;
    }
    int writeFd = dup(sendFd);
    if (writeFd < 0) {
        close(readFd);
        return;
    }
    FILE *reader = fdopen(readFd, "r");
    if (reader == NULL) {
        close(readFd);
        close(writeFd);
        return;
    }
    FILE *writer = fdopen(writeFd, "w");
    if (writer == NULL) {
        fclose(reader); // also closes readFd
        close(writeFd);
        return;
    }
    while (handle_connection_request(
                   reader, writer, authKey, stringStores) == 0) {
        // One request/response exchange per iteration.
    }
    // fclose releases the underlying duplicated descriptors exactly once.
    fclose(reader);
    fclose(writer);
}
/*
 * Thread entry point for one accepted client. Takes ownership of ciPtr (a
 * malloc'd struct ConnectionInfo, freed here) and of ci.fd (closed here).
 * If ci.connections is 0 (unlimited) or fewer than ci.connections clients
 * are currently being served, the client is served. It is counted in
 * currentConnections while active and in completedConnections afterwards.
 * Otherwise UNAVALIABLE_RESPONSE is written and the client is not counted.
 */
void *handle_connection(void *ciPtr) {
    struct ConnectionInfo ci = *(struct ConnectionInfo *)ciPtr;
    free(ciPtr);
    // Decide admission and claim the slot under one lock. The original read
    // the count after unlocking, which was a data race and let concurrent
    // arrivals observe each other's transient increments.
    int admitted = 0;
    pthread_mutex_lock(&(serverStats->connectionLock));
    if (ci.connections == 0
            || serverStats->currentConnections < ci.connections) {
        serverStats->currentConnections += 1;
        admitted = 1;
    }
    pthread_mutex_unlock(&(serverStats->connectionLock));
    if (admitted) {
        handle_client_connection(ci.fd, ci.authKey, ci.stringStores);
    } else {
        // Best effort: the client is turned away whether or not this lands.
        ssize_t written =
                write(ci.fd, UNAVALIABLE_RESPONSE, strlen(UNAVALIABLE_RESPONSE));
        (void)written;
    }
    close(ci.fd);
    if (admitted) {
        pthread_mutex_lock(&(serverStats->connectionLock));
        serverStats->currentConnections--;
        serverStats->completedConnections++;
        pthread_mutex_unlock(&(serverStats->connectionLock));
    }
    return NULL;
}
void print_server_stats(int s) {
fprintf(stderr,
"Connected clients:%d\nCompleted clients:%d\nAuth "
"failures:%d\nGET "
"operations:%d\nPUT operations:%d\nDELETE operations:%d\n",
serverStats->currentConnections, serverStats->completedConnections,
serverStats->authFails, serverStats->gets, serverStats->puts,
serverStats->deletes);
}
void setup_signal_handler(void) {
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = print_server_stats;
sa.sa_flags = SA_RESTART;
sigaction(SIGHUP, &sa, 0);
}
/* Prints the usage message to stderr and exits with EXIT_ERR. */
static void exit_with_usage(void) {
    fprintf(stderr, PROGRAM_USAGE);
    exit(EXIT_ERR);
}

/*
 * Usage: authfile connections [portnum]
 * Validates arguments (connections must be a non-negative number; a
 * non-empty portnum must be 0 or within PORT_MIN..PORT_MAX) and reads the
 * auth key from the first line of authfile. It then serves clients until
 * handle_listen returns, exiting with that function's error code when
 * positive.
 */
int main(int argc, char *argv[]) {
    serverstats_init();
    setup_signal_handler();
    if (!(argc >= 3 && argc <= 4)) {
        exit_with_usage();
    }
    int connections = get_valid_number(argv[2]);
    if (connections < 0) {
        exit_with_usage();
    }
    int portnum = 0;
    // argv[argc] is NULL, so argv[3] is safe to inspect when argc == 3.
    if (argv[3] != NULL && argv[3][0] != '\0') {
        portnum = get_valid_number(argv[3]);
        int outOfRange = portnum < PORT_MIN || portnum > PORT_MAX;
        if (portnum != 0 && outOfRange) {
            exit_with_usage();
        }
    }
    char *authKey = get_authkey_from_file(argv[1]);
    if (!authKey) {
        fprintf(stderr, "dbserver: unable to read authentication string\n");
        exit(EXIT_FILE_ERR);
    }
    int status = handle_listen(portnum, connections, authKey);
    if (status > 0) {
        fprintf(stderr, "dbserver: unable to open socket for listening\n");
        free(authKey);
        exit(status);
    }
    free(authKey);
    free(serverStats);
    return 0;
}