#include <string.h>
#include <pico/stdio.h>
#include <stdio.h>
#include <hardware/spi.h>
#include <hardware/gpio.h>
#include <pico/time.h>
#include "pico/platform.h"
#include "sdcard.h"
#include "spi.h"
#include "types.h"

// Deselect the card by driving its chip-select GPIO (sd->pin_cs) high.
// The rest of this file treats CS as active-low: commands are framed by
// sd_cs_enable() (low) ... sd_cs_disable() (high).
static inline void
sd_cs_disable(sdcard_t *sd) {
    gpio_put(sd->pin_cs, true);
}


// Select the card by driving its chip-select GPIO (sd->pin_cs) low.
static void
sd_cs_enable(sdcard_t *sd) {
    gpio_put(sd->pin_cs, false);
}


// Assert chip-select and clock out one 6-byte command frame:
//   byte 0    : command index with bit 6 (0x40) set
//   bytes 1-4 : 32-bit argument, most significant byte first
//   byte 5    : caller-supplied CRC byte
// Chip-select is left asserted so the caller can read the response;
// sd_end_cmd() releases it.
static void
sd_start_cmd(sdcard_t *sd, u8 cmd, u32 args, u8 crc) {
    enum { FRAME_LEN = 6 };
    const u8 frame[FRAME_LEN] = {
        (u8)(cmd | 0x40),
        (u8)((args >> 24) & 0xff),
        (u8)((args >> 16) & 0xff),
        (u8)((args >>  8) & 0xff),
        (u8)(args & 0xff),
        crc,
    };

    sd_cs_enable(sd);

    for (int idx = 0; idx < FRAME_LEN; ++idx) {
        spi_rwbyte(sd->spi, frame[idx]);
    }
}


// Poll the bus for a one-byte response.
// Clocks out 0xff and reads the returned byte until something other than
// 0xff comes back, giving up after 0x1000 polls.
// Returns the first non-0xff byte, or 0xff if every poll returned 0xff.
static u8
sd_get_r1_response(sdcard_t *sd) {
    u8 reply = 0xff;

    for (u32 polls_left = 0x1000; polls_left > 0; --polls_left) {
        reply = spi_rwbyte(sd->spi, 0xff);
        if (0xff != reply) {
            break;
        }
    }

    return reply;
}


// Read the 4 bytes that follow the first response byte of a command and
// assemble them, most significant byte first, into a 32-bit value.
// The caller must already have consumed the leading response byte
// (see sd_get_r1_response()).
static u32
sd_get_r3r7_response(sdcard_t *sd) {
    u32 ret = 0;

    for (int i = 0; i < 4; ++i) {
        u8 byte = spi_rwbyte(sd->spi, 0xff);
        // Accumulate in u32: shifting a u8 promotes it to signed int, and
        // `byte << 24` with byte >= 0x80 would shift into the sign bit,
        // which is undefined behaviour.
        ret = (ret << 8) | (u32)byte;
    }

    return ret;
}


// Finish a command transaction: release chip-select, then clock out four
// 0xff bytes with the card deselected.
static void
sd_end_cmd(sdcard_t *sd) {
    int trailing_bytes = 4;

    sd_cs_disable(sd);

    while (trailing_bytes-- > 0) {
        spi_rwbyte(sd->spi, 0xff);
    }
}


// Issue CMD58 and store the 32-bit payload that follows the first response
// byte in sd->ocr, then print it.
// The command is repeated until the first response byte is 0x00 or 0x01.
// NOTE(review): there is no retry limit, so this spins forever if the card
// never answers with one of those two values.
static void
sd_get_ocr(sdcard_t *sd) {
    u8 r1;

    do {
        sd_start_cmd(sd, SD_CMD58, 0, 1);
        r1 = sd_get_r1_response(sd);
        // sd->ocr is overwritten on every attempt, including rejected ones;
        // only the value from the final (accepted) attempt survives.
        sd->ocr = sd_get_r3r7_response(sd);
        sd_end_cmd(sd);
    } while (!(0x00 == r1 || 0x01 == r1));

    printf("OCR: 0x%08X.\n", sd->ocr);
}


// Read the 128-bit CSD register with CMD9 and cache it in the card state:
//   sd->csd_high = CSD bits [127..64]
//   sd->csd_low  = CSD bits [63..0]
// The register arrives as a data block: start token, 16 data bytes (most
// significant first), then a 2-byte CRC which is read and discarded.
// On any timeout a message is printed and sd->csd_high / sd->csd_low are
// left untouched.
static void
sd_get_csd(sdcard_t *sd) {
    i32 count;
    u32 token_polls;
    u8 resp;
    u8 token;
    u64 high = 0;
    u64 low = 0;

    sd_start_cmd(sd, SD_CMD9, 0x0, 0x0);

    count = 500;
    do {
        resp = sd_get_r1_response(sd);
        --count;
    } while (0xff == resp && count > 0);

    if (0xff == resp) {
        printf("SD init timeout: CMD9.\n");
        sd_end_cmd(sd);
        return;
    }

    // Wait for the data start token, but give up after a bounded number of
    // polls (same budget sd_read_block() uses) instead of hanging forever
    // when the card never sends it.
    token_polls = 0xffff;
    do {
        token = spi_rwbyte(sd->spi, 0xff);
        --token_polls;
    } while (SD_START_DATA_MULTIPLE_BLOCK_READ != token && 0 != token_polls);

    if (SD_START_DATA_MULTIPLE_BLOCK_READ != token) {
        printf("SD init timeout: CMD9 data token.\n");
        sd_end_cmd(sd);
        return;
    }

    // Assemble into zero-initialised locals so stale bits from an earlier
    // read (or a non-zeroed sdcard_t) can never leak into the result.
    for (int i=0 ; i<8 ; ++i) {
        high = (high << 8) | ((u64)spi_rwbyte(sd->spi, 0xff) & 0xff);
    }

    for (int i=0 ; i<8 ; ++i) {
        low = (low << 8) | ((u64)spi_rwbyte(sd->spi, 0xff) & 0xff);
    }

    // The data block is followed by a 16-bit CRC; clock it out and ignore
    // it, as sd_read_block() does.
    spi_rwbyte(sd->spi, 0xff);
    spi_rwbyte(sd->spi, 0xff);

    sd->csd_high = high;
    sd->csd_low  = low;

    sd_end_cmd(sd);
    sd_end_cmd(sd);

    printf("CSD: 0x%016llx%016llx.\n",
           (unsigned long long)sd->csd_high,
           (unsigned long long)sd->csd_low);
}


// Bring an SD card up in SPI mode and cache its OCR and CSD registers.
//
// Preconditions (not checked here): the SPI peripheral behind sd->spi and
// the SPI / chip-select pins are already configured.
//
// Sequence:
//   1. dummy clocks with the card deselected
//   2. CMD0, retried up to 500 times until the card answers
//   3. CMD8 to tell v1 cards (illegal-command reply) from v2 cards
//   4. CMD55 + ACMD41 until the idle bit clears
//   5. read OCR and CSD
// On a CMD0 / CMD8 timeout a message is printed and the function returns
// without setting sd->initialized.
void
sd_init(sdcard_t *sd) {
    u8 resp = 0xff;
    u32 r3r7_resp = 0;
    i32 count;


    // assume spi bus is initialized
    // assume spi pins are initialized

    // Power-up sequence: the card needs at least 74 clock cycles with
    // chip-select held HIGH (deasserted) and MOSI high before the first
    // command. 20 bytes of 0xff = 160 clocks.
    sd_cs_disable(sd);
    for (int i=0 ; i<20 ; ++i) {
        spi_rwbyte(sd->spi, 0xff);
    }


    // send cmd0
    count = 500;
    do {
        sd_start_cmd(sd, SD_CMD0, 0, 0x95);
        resp = sd_get_r1_response(sd);
        sd_end_cmd(sd);
        --count;
    } while (0xff == resp && count > 0);

    if (0xff == resp) {
        printf("SD init timeout: CMD0.\n");
        return;
    }


    // discriminate card type
    sd_start_cmd(sd, SD_CMD8, 0x1aa, 0x87);
    resp = sd_get_r1_response(sd);
    r3r7_resp = sd_get_r3r7_response(sd);
    sd_end_cmd(sd);
    printf("CMD8 returns: 0x%08X.\n", (unsigned)r3r7_resp);
    if (0xff == resp) {
        // This is the CMD8 failure path; report the right command.
        printf("SD init timeout: CMD8.\n");
        return;
    }

    for (int i=0 ; i<10 ; ++i) {
        spi_rwbyte(sd->spi, 0xff);
    }

    if (SD_R1_ILLEGAL_CMD & resp) {
        // 1.0 card
        sd->version = 1;
    } else {
        // 2.0 card
        sd->version = 2;
    }

    // first fire cmd55 + acmd41 to ensure card is ready
    // i.e. get out of 'IDLE STATE'
    // NOTE(review): this loop has no retry limit; a card that never leaves
    // the idle state will hang here. A safe bound depends on the SPI clock,
    // so it is left for the integrator to choose.
    while (SD_R1_IN_IDLE_STATE & resp) {
        sd_start_cmd(sd, SD_CMD55, 0, 0);
        sd_get_r1_response(sd);
        sd_end_cmd(sd);
        sd_start_cmd(sd, SD_ACMD41, 0x40000000, 0);
        resp = sd_get_r1_response(sd);
        printf("ACMD41 returns 0x%02X.\n", (int)resp);
        sd_end_cmd(sd);
    }

    // get sdcard registers
    sd_get_ocr(sd);
    sd_get_csd(sd);

    // put some dummy bytes
    sd_end_cmd(sd);
    sd_end_cmd(sd);

    sd_cs_disable(sd);
    sd->initialized = true;

    printf("This sd card is %cc card.\n", sd_is_hc(sd) ? 'h' : 's');
    printf("This sd card has %.2f MiB capacity.\n", sd_size_byte(sd) / 1024.0 / 1024.0);
}


// Read one 512-byte block from the card into buf using CMD17.
//
//   sd       : card state
//   buf      : destination, must have room for 512 bytes
//   block_id : block index; when sd_is_hc() is false it is converted to a
//              byte offset (x512) before being sent
//
// Failures are reported only by printing a message; buf is left unmodified
// in that case and nothing is returned to the caller.
void
sd_read_block(sdcard_t *sd, u8 *buf, u32 block_id) {
    u8 reply = 0xff;
    u32 address = sd_is_hc(sd) ? block_id : (block_id << 9);

    // Issue CMD17 until the card accepts it (first response byte 0x00):
    // one initial attempt plus up to 0xff retries.
    for (u32 retries_left = 0xff; ; --retries_left) {
        sd_start_cmd(sd, SD_CMD17, address, 0);
        reply = sd_get_r1_response(sd);

        if (0x00 == reply) {
            break;
        }

        if (0 == retries_left) {
            printf("read block fail. CMD17 timeout.\n");
            sd_end_cmd(sd);
            return;
        }

        sd_end_cmd(sd);
    }

    // Poll for the start-of-data token, at most 0xffff bytes.
    for (u32 polls_left = 0xffff; polls_left > 0; --polls_left) {
        reply = spi_rwbyte(sd->spi, 0xff);
        if (SD_START_DATA_SINGLE_BLOCK_READ == reply) {
            break;
        }
    }

    if (SD_START_DATA_SINGLE_BLOCK_READ != reply) {
        printf("wait START tag 0xFE timeout.\n");
        sd_end_cmd(sd);
        return;
    }

    // Payload: 512 data bytes.
    u8 *dst = buf;
    u8 *const dst_end = buf + 512;
    while (dst != dst_end) {
        *dst++ = spi_rwbyte(sd->spi, 0xff);
    }

    // Two CRC bytes follow the payload; they are clocked out and ignored.
    spi_rwbyte(sd->spi, 0xff);
    spi_rwbyte(sd->spi, 0xff);

    // Deselect the card and send trailing dummy bytes (twice).
    sd_end_cmd(sd);
    sd_end_cmd(sd);
}


// Write one 512-byte block from buf to the card using CMD24.
//
//   sd       : card state
//   buf      : source, 512 bytes are read from it
//   block_id : block index; when sd_is_hc() is false it is converted to a
//              byte offset (x512) before being sent
//
// Failures are reported only by printing a message. If CMD24 is never
// accepted the function returns before sending any data.
void
sd_write_block(sdcard_t *sd, const u8 *buf, u32 block_id) {
    u8 resp = 0xff;

    if (!sd_is_hc(sd)) {
        block_id <<= 9;
    }

    // Issue CMD24 until the card accepts it (first response byte 0x00):
    // one initial attempt plus up to 0xff retries.
    u32 retry_cnt = 0xff;
    do {
        sd_start_cmd(sd, SD_CMD24, block_id, 0);
        resp = sd_get_r1_response(sd);

        if (0x00 == resp) {
            break;
        }

        if (0 == retry_cnt) {
            printf("write block fail. CMD24 timeout.\n");
            sd_end_cmd(sd);
            return;
        }

        sd_end_cmd(sd);
        --retry_cnt;
    } while (1);

    // One gap byte, then the start-of-data token.
    spi_rwbyte(sd->spi, 0xff);
    spi_rwbyte(sd->spi, SD_START_DATA_SINGLE_BLOCK_WRITE);

    for (int i=0 ; i<512 ; ++i) {
        spi_rwbyte(sd->spi, buf[i]);
    }

    // dummy crc
    spi_rwbyte(sd->spi, 0xff);
    spi_rwbyte(sd->spi, 0xff);

    // Poll (at most 0xff bytes) for the card's data-response byte.
    retry_cnt = 0xff;
    do {
        resp = spi_rwbyte(sd->spi, 0xff);
        --retry_cnt;
    } while (0xff == resp && retry_cnt > 0);

    // Low five bits 0b00101 mean "data accepted"; anything else is a
    // rejection (or no response at all).
    if ((resp & 0x1f) != 0x5) {
        // Terminate the line like every other log message in this file so
        // the next message does not get glued onto this one.
        printf("write sector failed, sd card respond 0x%02x.\n", resp);
    }

    // The card holds the data line away from 0xff while it is busy
    // programming; wait until the bus reads idle again.
    // NOTE(review): this wait is unbounded.
    while (spi_rwbyte(sd->spi, 0xff) != 0xff) {
        tight_loop_contents();
    }


    // disable cs and send some extra dummy bytes
    sd_end_cmd(sd);
    sd_end_cmd(sd);

    return;
}


// Compute the card capacity in bytes from the cached CSD register
// (sd->csd_high = CSD bits [127..64], sd->csd_low = CSD bits [63..0],
// as stored by sd_get_csd()).
//
// The register layout is selected by the CSD_STRUCTURE field inside the
// CSD itself rather than by sd->version: a card that answers CMD8
// (version 2) may still be a standard-capacity card using the v1.0 layout.
//
//   CSD v2.0: capacity = (C_SIZE + 1) * 512 KiB
//   CSD v1.0: capacity = (C_SIZE + 1) * 2^(C_SIZE_MULT + 2) * 2^READ_BL_LEN
u64
sd_size_byte(sdcard_t *sd) {
    u8 csd_structure = (sd->csd_high >> 62) & 0x3; // [127..126]

    if (csd_structure == 1) {
        u32 c_size_low  = (sd->csd_low  >> 48) & 0xffff; // [63..48]
        u32 c_size_high = (sd->csd_high &  0x3f);        // [69..64]
        u64 c_size = c_size_low | (c_size_high << 16);   // C_SIZE [69..48]

        return (c_size + 1) << 19;
    }

    u32 c_size_low  = (sd->csd_low >> 62) & 0x3;  // [63..62]
    u32 c_size_high = (sd->csd_high & 0x3ff);     // [73..64]
    u64 c_size = c_size_low | (c_size_high << 2); // C_SIZE [73..62]

    // C_SIZE_MULT is a 3-bit field, so the mask must be 0x7.
    u8 c_size_mult = (sd->csd_low  >> 47) & 0x7; // [49..47]
    // READ_BL_LEN sits at CSD bits [83..80], i.e. csd_high bits [19..16];
    // bits [95..84] are the CCC field.
    u8 read_bl_len = (sd->csd_high >> 16) & 0xf; // [83..80]

    return (c_size + 1) * ((u64)1 << (c_size_mult + 2)) * ((u64)1 << read_bl_len);
}
