/*
 * Copyright (c) 2018-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#pragma once

#include <stdint.h>
#include <array>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include "TimestampConverter.h"

namespace NvTraceFormat {

// File-wide summary data.  This struct is embedded in FileHeader, which is
// read and written as raw bytes, so member order, types and the explicit
// padding are all part of the file format.
struct GlobalData
{
    int64_t cpuTimestampTicksPerSecond;  // On x86, frequency of RDTSC ticks
    int64_t earliestTimestamp;  // Min of all captured timestamps for any device
    int64_t latestTimestamp;    // Max of all captured timestamps for any device
    int32_t totalEventCount;    // Total number of records captured for all devices
    uint8_t boolGpuTimestampsAlreadyConvertedToCpu; // 1 if true, 0 if false
    uint8_t padding[3];         // Explicit padding so no bytes are left to the compiler
};

// 3 * 8 + 4 + 1 + 3 bytes.  Fail the build rather than silently produce or
// misread files if a compiler ever lays this struct out differently.
static_assert(sizeof(GlobalData) == 32,
    "GlobalData is serialized byte-for-byte; its size is part of the file format");

// Fixed-size tag stored at the very start of a trace file.
using nvtrcMagic_t = std::array<char, 8>;

// Magic for the current file format version: seven characters plus the
// terminating null, filling all eight bytes.
// Declared inline constexpr so that every translation unit including this
// header refers to the same object; the inline functions further down in this
// header bind references to it.
inline constexpr nvtrcMagic_t nvtrcVersionMagic{"nvtrc02"};
struct FileHeader
{
    nvtrcMagic_t magic;
    GlobalData global;
};

// Prefix written in front of every serialized array (see WriteVector and
// ReadVector).  Recording the element size lets a reader detect files whose
// elements are larger or smaller than the struct it was compiled with.
struct ArrayHeader
{
    int32_t count;        // Number of elements that follow
    int32_t elementSize;  // Size in bytes of each element as stored in the file
};

// Transferred as raw bytes, so its size is part of the file format.
static_assert(sizeof(ArrayHeader) == 8,
    "ArrayHeader is serialized byte-for-byte; its size is part of the file format");

// Per-device status of GPU context switch trace, stored as a single byte in
// DeviceDesc.  GpuCtxSwTraceErrorToString maps each value to a message.
enum class GpuCtxSwTraceError8 : uint8_t
{
    None = 0,               // No error; trace is supported on this device
    UnsupportedGpu = 1,     // GPU is older than Pascal
    UnsupportedDriver = 2,  // Installed display driver lacks support
    NeedRoot = 3,           // Process is not running as root/Administrator
    Unknown = 255           // Reported as an internal error
};

// Description of one GPU.  Arrays of this struct are transferred as raw bytes
// by WriteVector/ReadVector, so member order and types are part of the file
// format.
struct DeviceDesc
{
    // Two CPU/GPU timestamp pairs, which GpuToCpuTimestampConverter uses to
    // map GPU timestamps to CPU timestamps.
    int64_t cpuTimestampStart;  // On x86, RDTSC
    int64_t gpuTimestampStart;  // NVIDIA GPU globaltimer
    int64_t cpuTimestampEnd;    // On x86, RDTSC
    int64_t gpuTimestampEnd;    // NVIDIA GPU globaltimer

    int64_t earliestTimestamp;  // Min of all captured timestamps for this device
    int64_t latestTimestamp;    // Max of all captured timestamps for this device

    uint8_t uuid[16]; // As in nvidia-smi and VkPhysicalDeviceIDProperties::deviceUUID

    char name[190];   // Null-terminated string in fixed-size buffer
    GpuCtxSwTraceError8 gpuCtxSwTraceError;
    uint8_t boolTimestampsInSortedOrder;            // 1 if true, 0 if false
};

// 6 * 8 + 16 + 190 + 1 + 1 bytes.  Fail the build rather than silently produce
// or misread files if a compiler ever lays this struct out differently.
static_assert(sizeof(DeviceDesc) == 256,
    "DeviceDesc is serialized byte-for-byte; its size is part of the file format");

// Category tag stored in the first 16 bits of a record.
// PrettyPrintFileDataRecords only prints records tagged GpuContextSwitch.
enum class Category16 : uint16_t
{
    Invalid = 0,
    GpuContextSwitch = 1,  // Record is a RecordGpuCtxSw context switch event
    Reserved0 = 2
};

// Event type of a GPU context switch record (RecordGpuCtxSw::type).
enum class TypeGpuCtxSw16 : uint16_t
{
    Invalid = 0,
    ContextSwitchedIn = 1,   // Printed as "Context Start" by PrettyPrintFileDataRecords
    ContextSwitchedOut = 2   // Printed as "Context Stop" by PrettyPrintFileDataRecords
};

// One captured GPU context switch event.  Arrays of this struct are
// transferred as raw bytes by WriteVector/ReadVector, so member order and
// types are part of the file format.
struct RecordGpuCtxSw
{
    Category16      category;       // Category16::GpuContextSwitch for records of this kind
    TypeGpuCtxSw16  type;           // Switched in or switched out
    uint32_t        processId;      // Printed as "PID" by PrettyPrintFileDataRecords
    int64_t         timestamp;      // CPU time if GlobalData::boolGpuTimestampsAlreadyConvertedToCpu
                                    // is set, otherwise raw GPU time
    uint64_t        contextHandle;  // Printed as "ContextID" by PrettyPrintFileDataRecords
};

// 2 + 2 + 4 + 8 + 8 bytes.  Fail the build rather than silently produce or
// misread files if a compiler ever lays this struct out differently.
static_assert(sizeof(RecordGpuCtxSw) == 24,
    "RecordGpuCtxSw is serialized byte-for-byte; its size is part of the file format");

// In-memory form of a whole trace file, as filled in by ReadFileDataVersion
// and consumed by WriteFileDataVersion and the PrettyPrint helpers.
// deviceDescs and perDeviceData are parallel: entry i of each refers to the
// same device.  WriteFileDataVersion refuses to write if their sizes differ.
struct FileData
{
    GlobalData global;
    std::vector<DeviceDesc> deviceDescs;
    std::vector<std::vector<RecordGpuCtxSw>> perDeviceData; // First index is device, second is record
};

// Fill `value` with raw bytes taken from `ifs`.
// By default exactly sizeof(T) bytes are read; pass a smaller sizeOfValue to
// read only a leading part of the object.  Failure is reported through the
// stream state, which the caller is expected to test.
template <class T>
inline void Read(std::ifstream& ifs, T& value, int sizeOfValue = sizeof(T))
{
    char* const destination = reinterpret_cast<char*>(&value);
    ifs.read(destination, sizeOfValue);
}

// Read one serialized array (an ArrayHeader followed by its elements) from
// `ifs` into `buffer`.
//
// Returns true on success.  Returns false if the stream fails, if the header
// is corrupt (negative count or non-positive element size), or if the stored
// elements are smaller than T (file written by an older version).
//
// If the stored elements are larger than T (file written by a newer version),
// the leading sizeof(T) bytes of each element are kept and the remaining bytes
// of that element are skipped.
template <class T, class Alloc>
inline bool ReadVector(std::ifstream& ifs, std::vector<T, Alloc>& buffer)
{
    ArrayHeader header;
    Read(ifs, header);
    if (!ifs) return false;

    // A negative value here can only come from a corrupt file.  Reject it
    // before it is converted to an enormous unsigned size.
    if (header.count < 0 || header.elementSize <= 0) return false;

    auto const storedElementSize = static_cast<size_t>(header.elementSize);
    if (storedElementSize < sizeof(T))
    {
        // File has older version than expected.
        // Could attempt to upconvert, but for now simply fail.
        return false;
    }

    buffer.resize(static_cast<size_t>(header.count));

    // Nothing follows the header for an empty array.
    if (buffer.empty()) return true;

    if (storedElementSize == sizeof(T))
    {
        // Layouts match: read every element with a single call.  The byte
        // count is computed in streamsize to avoid 32-bit overflow.
        auto const byteCount =
            static_cast<std::streamsize>(buffer.size()) * static_cast<std::streamsize>(sizeof(T));
        ifs.read(reinterpret_cast<char*>(buffer.data()), byteCount);
    }
    else
    {
        // Stored elements carry extra trailing bytes this build does not know
        // about.  Keep the known prefix and step over the rest so the next
        // element is read from the correct offset.
        auto const trailingBytes = static_cast<std::streamsize>(storedElementSize - sizeof(T));
        for (T& elem : buffer)
        {
            ifs.read(reinterpret_cast<char*>(&elem), sizeof(T));
            ifs.ignore(trailingBytes);

            // ignore() only sets eofbit when it runs out of data, so check
            // explicitly that the whole tail was present.
            if (ifs.gcount() != trailingBytes) return false;
        }
    }

    return static_cast<bool>(ifs);
}

// Write the raw bytes of `value` (sizeof(T) of them) to `ofs`.
// The value is taken by const reference so that const objects and temporaries
// can be written as well.  Failure is reported through the stream state.
template <class T>
inline void Write(std::ofstream& ofs, T const& value) noexcept
{
    ofs.write(reinterpret_cast<char const*>(&value), sizeof(T));
}

// Write `buffer` to `ofs` as an ArrayHeader followed by the raw bytes of its
// elements.  Failure is reported through the stream state.
//
// The header stores the element count as int32_t.  A vector too large to be
// described that way is not written at all; failbit is set on the stream
// instead, so the caller's stream check reports the failure.
template <class T, class Alloc>
inline void WriteVector(std::ofstream& ofs, std::vector<T, Alloc> const& buffer) noexcept
{
    if (buffer.size() > static_cast<size_t>(INT32_MAX))
    {
        ofs.setstate(std::ios::failbit);
        return;
    }

    ArrayHeader header{static_cast<int32_t>(buffer.size()), static_cast<int32_t>(sizeof(T))};
    Write(ofs, header);

    // An empty array consists of the header only.
    if (buffer.empty()) return;

    // Computed in streamsize so large arrays do not overflow 32-bit arithmetic.
    auto const byteCount =
        static_cast<std::streamsize>(header.count) * static_cast<std::streamsize>(header.elementSize);
    ofs.write(reinterpret_cast<char const*>(buffer.data()), byteCount);
}

inline bool ReadFileDataVersion(char const* inputFile, FileData& fileData, nvtrcMagic_t const& magic)
{
    fileData = FileData();

    std::ifstream ifs(inputFile, std::ios::in | std::ios::binary);
    if (!ifs) return false;

    FileHeader header;
    Read(ifs, header);
    if (!ifs) return false;

    if (header.magic != magic) return false;

    fileData.global = header.global;

    bool success = ReadVector(ifs, fileData.deviceDescs);
    if (!success) return false;

    fileData.perDeviceData.resize(fileData.deviceDescs.size());
    for (auto& deviceData : fileData.perDeviceData)
    {
        bool success = ReadVector(ifs, deviceData);
        if (!success) return false;
    }

    return true;
}

// Write `fileData` to `outputFile`, tagging the file with `magic`.
//
// Layout: FileHeader, then the device description array, then one record
// array per device in device order.
//
// Returns false, without creating the file, if deviceDescs and perDeviceData
// differ in size.  Also returns false if the file cannot be opened, if any
// write fails, or if an exception is thrown.
inline bool WriteFileDataVersion(char const* outputFile, FileData const& fileData, nvtrcMagic_t const& magic) noexcept
{
    if (fileData.deviceDescs.size() != fileData.perDeviceData.size()) return false;

    try
    {
        std::ofstream ofs(outputFile, std::ios::out | std::ios::binary);
        if (!ofs) return false;

        FileHeader header{magic, fileData.global};
        Write(ofs, header);

        WriteVector(ofs, fileData.deviceDescs);

        for (auto const& deviceData : fileData.perDeviceData)
        {
            WriteVector(ofs, deviceData);
        }

        // Close explicitly before judging success: buffered data is only
        // pushed to the file on flush/close, and an error raised then (for
        // example a full disk) would otherwise go unnoticed in the destructor.
        ofs.close();
        if (!ofs) return false;
    }
    catch(...)
    {
        return false;
    }

    return true;
}

// Helper functions for reading/writing current version, but factored so
// tools can read other versions using the same header.
// Load `inputFile` into `fileData`, accepting only the current format
// version (nvtrcVersionMagic).  Returns true on success.
inline bool ReadFileData(char const* inputFile, FileData& fileData)
{
    auto const& currentMagic = nvtrcVersionMagic;
    return ReadFileDataVersion(inputFile, fileData, currentMagic);
}
// Write `fileData` to `outputFile` using the current format version
// (nvtrcVersionMagic).  Returns true on success.
inline bool WriteFileData(char const* outputFile, FileData const& fileData) noexcept
{
    auto const& currentMagic = nvtrcVersionMagic;
    return WriteFileDataVersion(outputFile, fileData, currentMagic);
}


// Build a converter from GPU timestamps to CPU timestamps for one device,
// using the two synchronization points (start and end) recorded in its
// DeviceDesc.  Use this when boolGpuTimestampsAlreadyConvertedToCpu in the
// global data is 0, i.e. the records still carry GPU timestamps.
inline TimestampConverter GpuToCpuTimestampConverter(DeviceDesc const& desc)
{
    // Source range: GPU time
    auto const gpuStart = desc.gpuTimestampStart;
    auto const gpuEnd   = desc.gpuTimestampEnd;

    // Destination range: CPU time
    auto const cpuStart = desc.cpuTimestampStart;
    auto const cpuEnd   = desc.cpuTimestampEnd;

    return CreateTimestampConverter(gpuStart, gpuEnd, cpuStart, cpuEnd);
}

// Convert the distance between two CPU timestamps into seconds.
// The tick difference (cpuTimestamp - earliestCpuTimestamp) is divided by
// cpuTimestampTicksPerSecond; the result is negative if cpuTimestamp is the
// earlier of the two.
inline double SecondsElapsed(
    int64_t cpuTimestamp,
    int64_t earliestCpuTimestamp,
    int64_t cpuTimestampTicksPerSecond)
{
    int64_t const elapsedTicks = cpuTimestamp - earliestCpuTimestamp;
    double const ticks = static_cast<double>(elapsedTicks);
    double const ticksPerSecond = static_cast<double>(cpuTimestampTicksPerSecond);
    return ticks / ticksPerSecond;
}

// Express a CPU timestamp as seconds since the earliest timestamp in the
// capture, using the tick frequency stored in the file's global data.
inline double ToSeconds(
    int64_t cpuTimestamp,
    FileData const& fileData)
{
    auto const& global = fileData.global;
    return SecondsElapsed(cpuTimestamp, global.earliestTimestamp, global.cpuTimestampTicksPerSecond);
}


// Note that timestamps are automatically converted to CPU time unless raw GPU
// timestamps were explicitly requested.  The automatic conversion effectively
// works like this:
// for (size_t deviceIndex = 0; deviceIndex < fileData.deviceDescs.size(); ++deviceIndex)
// {
//     auto const& deviceDesc = fileData.deviceDescs[deviceIndex];
//     auto& records = fileData.perDeviceData[deviceIndex];
//
//     auto convertToCpuTime = GpuToCpuTimestampConverter(deviceDesc);
//
//     for (auto& record : records)
//         record.timestamp = convertToCpuTime(record.timestamp);
// }
//
// In the case of merging multiple FileData objects onto a single timeline, it
// is most accurate to leave all the timestamps in GPU time, and then convert
// them all afterwards using a single conversion factor.  Create this common
// converter using the start time of the earliest capture and the end time of
// the latest capture (remembering to handle this separately for each device).

// Store `name` in desc.name, truncating it if it does not fit.
// The result is always null-terminated, and every byte after the copied
// characters is set to zero.  DeviceDesc is written to trace files as raw
// bytes, so this keeps leftover buffer contents out of the file.
inline void SetName(DeviceDesc& desc, std::string const& name)
{
    size_t const capacity = sizeof(desc.name);

    // Reserve one byte for the terminator.  Written without std::min to avoid
    // min() macro trouble.
    size_t const maxChars = capacity - 1;
    size_t const copyLength = name.size() < maxChars ? name.size() : maxChars;

    std::memcpy(&desc.name[0], name.data(), copyLength);

    // Terminator plus the whole unused tail.
    std::memset(&desc.name[copyLength], 0, capacity - copyLength);
}

// Format a UUID as lowercase hex in the usual 8-4-4-4-12 digit grouping,
// e.g. "00010203-0405-0607-0809-0a0b0c0d0e0f".
// `uuid` must point to 16 readable bytes.
inline std::string PrintableUuid(uint8_t const* uuid)
{
    // Number of bytes in each dash-separated group
    const int groupByteCounts[] = {4, 2, 2, 2, 6};

    std::ostringstream text;
    text << std::hex << std::setfill('0');

    uint8_t const* cursor = uuid;
    bool isFirstGroup = true;
    for (int byteCount : groupByteCounts)
    {
        if (!isFirstGroup)
        {
            text << '-';
        }
        isFirstGroup = false;

        for (int i = 0; i < byteCount; ++i, ++cursor)
        {
            // Widen so the byte is printed as a number, not as a character
            text << std::setw(2) << static_cast<unsigned>(*cursor);
        }
    }

    return text.str();
}

// Translate a GPU context switch trace status into a user-facing message.
// For GpuCtxSwTraceError8::None there is no message, so `defaultVal` (nullptr
// unless the caller supplies one) is returned.  Unknown, and any value not
// listed in the enum, yields the generic internal-error message.
inline const char* GpuCtxSwTraceErrorToString(GpuCtxSwTraceError8 code, const char* defaultVal = nullptr)
{
    if (code == GpuCtxSwTraceError8::None)
    {
        return defaultVal;
    }
    if (code == GpuCtxSwTraceError8::UnsupportedGpu)
    {
        return "GPU must be Pascal or newer to use GPU context switch trace";
    }
    if (code == GpuCtxSwTraceError8::UnsupportedDriver)
    {
        return "Installed NVIDIA display driver does not support GPU context switch trace";
    }
    if (code == GpuCtxSwTraceError8::NeedRoot)
    {
        return "Process must be running as root/Administrator to use GPU context switch trace";
    }

    // GpuCtxSwTraceError8::Unknown and anything unrecognized
    return "Internal error occurred, please report to NVIDIA";
}

// Use these ostream manipulators to concisely stream out hex values, prefixed with 0x
// and padded with leading zeros to be fixed width.  For example:
//     std::cout << Hex64(0xFFFF'0000'FFFF) << " " << Hex32(0xFFFF);
// writes 0x0000ffff0000ffff 0x0000ffff.
//
// The "0x" prefix is written for every value, including zero, and 8-bit values
// are printed as two hex digits rather than as a character.  The stream's
// flags, fill character and width are the same after the call as before it.
template <typename Value>
class Hex
{
    Value n;
public:
    // Accept any input type and narrow or widen it to Value explicitly.
    template <typename InputValue>
    Hex(InputValue n_) : n(static_cast<Value>(n_)) {}

    friend std::ostream& operator<<(std::ostream& os, Hex const& value)
    {
        using Unsigned = std::make_unsigned_t<Value>;

        auto const oldFlags = os.flags();
        auto const oldFill = os.fill();
        auto const oldWidth = os.width();

        // Write the prefix by hand: std::showbase omits it for zero.
        // Clear the width first so the prefix itself is not padded.
        os.width(0);
        os << "0x";

        // 2 digits per byte.  Go through the unsigned form of Value so that
        // negative values are not sign-extended, then widen to a type that
        // operator<< never treats as a character.
        os << std::hex << std::noshowbase << std::right
           << std::setfill('0') << std::setw(static_cast<int>(2 * sizeof(Value)))
           << static_cast<unsigned long long>(static_cast<Unsigned>(value.n));

        os.fill(oldFill);
        os.width(oldWidth);
        os.flags(oldFlags);
        return os;
    }
};
using Hex64 = Hex<uint64_t>;
using Hex32 = Hex<uint32_t>;
using Hex16 = Hex<uint16_t>;
using Hex8  = Hex<uint8_t>;

// Write a human-readable summary of fileData.global to `os`: tick frequency,
// whether GPU timestamps were pre-converted, and the total event count.
// When at least one event was captured, the earliest and latest timestamps
// follow, plus the capture duration in seconds if the timestamps are already
// in CPU time.
inline void PrettyPrintFileDataGlobal(
    std::ostream& os,
    FileData const& fileData)
{
    auto const& global = fileData.global;
    bool const timestampsAreCpuTime = global.boolGpuTimestampsAlreadyConvertedToCpu != 0;

    os << "CPU timestamp ticks per second: " << global.cpuTimestampTicksPerSecond << std::endl;
    os << "GPU timestamps pre-converted to CPU equivalents: "
       << (timestampsAreCpuTime ? "Yes" : "No") << std::endl;
    os << "Total number of events captured: " << global.totalEventCount << std::endl;

    // No timestamp range to report for an empty capture
    if (global.totalEventCount <= 0)
    {
        return;
    }

    os << "Range of timestamps captured:" << std::endl;
    os << "  Earliest: " << Hex64(global.earliestTimestamp) << std::endl;
    os << "  Latest:   " << Hex64(global.latestTimestamp) << std::endl;

    if (timestampsAreCpuTime)
    {
        // ToSeconds measures from the earliest timestamp, so applying it to
        // the latest timestamp yields the total duration.
        os << "Total duration of captured events: "
           << ToSeconds(global.latestTimestamp, fileData) << " seconds" << std::endl;
    }
}


// Stream out a textual representation of a FileData's DeviceDesc list.
// For each device this prints its name, UUID, whether GPU context switch trace
// is supported (with the reason if not), the CPU/GPU synchronization
// timestamps, the number of captured records and, if there are any, the range
// of captured timestamps.
//
// A device with no matching entry in perDeviceData is reported as having zero
// records, and the name is never read beyond the end of its fixed-size buffer.
inline void PrettyPrintFileDataDeviceDescs(
    std::ostream& os,
    FileData const& fileData)
{
    using std::endl;
    auto toSec = [&fileData](int64_t cpuTimestamp){ return ToSeconds(cpuTimestamp, fileData); };

    for (size_t d = 0; d < fileData.deviceDescs.size(); ++d)
    {
        auto const& desc = fileData.deviceDescs[d];

        // The two vectors are meant to be parallel, but a hand-built FileData
        // may not honor that; never index past the end of perDeviceData.
        size_t const recordCount =
            (d < fileData.perDeviceData.size()) ? fileData.perDeviceData[d].size() : 0;

        // The name comes straight from the file and may lack a terminator, so
        // bound its length by the size of the buffer.
        void const* terminator = std::memchr(&desc.name[0], '\0', sizeof(desc.name));
        size_t const nameLength = terminator
            ? static_cast<size_t>(static_cast<char const*>(terminator) - &desc.name[0])
            : sizeof(desc.name);

        os << "Device " << d << ":" << endl << "\tName: ";
        os.write(&desc.name[0], static_cast<std::streamsize>(nameLength));
        os << endl <<
            "\tUUID: {" << PrintableUuid(&desc.uuid[0]) << "}" << endl <<
            "\tSupports GPU context-switch trace: ";

        // A null message means there is no error for this device
        const char* ctxswError = GpuCtxSwTraceErrorToString(desc.gpuCtxSwTraceError);
        if (!ctxswError)
        {
            os << "Yes" << endl;
        }
        else
        {
            os << "No -- " << ctxswError << endl;
        }

        os <<
            "\tTimestamps for synchronization:\n" <<
            "\t  CPU start: "  << Hex64(desc.cpuTimestampStart) <<
                " GPU start: " << Hex64(desc.gpuTimestampStart) << endl <<
            "\t  CPU end:   "  << Hex64(desc.cpuTimestampEnd) <<
                " GPU end:   " << Hex64(desc.gpuTimestampEnd) << endl <<
            "\tNumber of events captured: " << recordCount << endl;

        if (recordCount > 0)
        {
            if (fileData.global.boolGpuTimestampsAlreadyConvertedToCpu)
            {
                // CPU timestamps can also be shown as seconds into the capture
                os <<
                    "\tRange of timestamps captured:" << endl <<
                    "\t  Earliest: " << Hex64(desc.earliestTimestamp) <<
                        " ("         << toSec(desc.earliestTimestamp) << " seconds)" << endl <<
                    "\t  Latest:   " << Hex64(desc.latestTimestamp) <<
                        " ("         << toSec(desc.latestTimestamp) << " seconds)" << endl;
            }
            else
            {
                os <<
                    "\tRange of timestamps captured:" << endl <<
                    "\t  Earliest: " << Hex64(desc.earliestTimestamp) << endl <<
                    "\t  Latest:   " << Hex64(desc.latestTimestamp) << endl;
            }
        }
    }
}

// Stream out a textual representation of a FileData's record list.
// Prints a heading for every device, followed by one line per record whose
// category is GpuContextSwitch; records of any other category are skipped.
//
// A device with no matching entry in perDeviceData gets a heading and no
// records.  The stream's formatting flags are restored before returning.
inline void PrettyPrintFileDataRecords(
    std::ostream& os,
    FileData const& fileData)
{
    using std::endl;

    // std::left / std::right below are sticky; remember the caller's flags so
    // they can be put back afterwards.
    auto const callerFlags = os.flags();

    for (size_t d = 0; d < fileData.deviceDescs.size(); ++d)
    {
        os << "Device " << d << " records:" << endl;

        // The two vectors are meant to be parallel, but a hand-built FileData
        // may not honor that; never index past the end of perDeviceData.
        if (d >= fileData.perDeviceData.size())
        {
            continue;
        }

        for (auto const& record : fileData.perDeviceData[d])
        {
            if (record.category == Category16::GpuContextSwitch)
            {
                char const* type =
                    (record.type == TypeGpuCtxSw16::ContextSwitchedIn) ? "Context Start" :
                    (record.type == TypeGpuCtxSw16::ContextSwitchedOut) ? "Context Stop" :
                    "<Other>";
                // NOTE(review): contextHandle is 64 bits wide but only its low
                // 32 bits are printed here -- confirm this is intended.
                os <<
                    "\tTimestamp: " << Hex64(record.timestamp) <<
                    " | Event: " << std::left << std::setw(13) << type <<
                    " | PID: " << std::right << std::setw(5) << record.processId <<
                    " | ContextID: " << Hex32(record.contextHandle) << endl;
            }
        }
    }

    os.flags(callerFlags);
}

// Write a human-readable dump of a whole FileData to `os`.
// The global summary is always printed.  The device descriptions and the
// per-device record lists can each be switched off; the record section, when
// shown, is preceded by a blank line.
inline void PrettyPrintFileData(
    std::ostream& os,
    FileData const& fileData,
    bool showDeviceDescs = true,
    bool showRecords = true)
{
    PrettyPrintFileDataGlobal(os, fileData);

    if (showDeviceDescs)
        PrettyPrintFileDataDeviceDescs(os, fileData);

    if (!showRecords)
        return;

    os << std::endl;
    PrettyPrintFileDataRecords(os, fileData);
}

} // namespace
