Add support for correctly parsing Expressions and Statements

This commit is contained in:
Garrett Dickinson 2022-07-10 18:45:12 -05:00
parent 5cd5dda34d
commit c8b7214e6f
8 changed files with 365 additions and 65 deletions

View File

@ -10,7 +10,10 @@ import java.util.List;
public class Cobalt {
private static final Interpreter interpreter = new Interpreter();
static boolean hadError = false;
static boolean hadRuntimeError = false;
public static void main(String[] args) throws IOException {
if (args.length > 1) {
@ -28,6 +31,7 @@ public class Cobalt {
byte[] bytes = Files.readAllBytes(Paths.get(path));
run(new String(bytes, Charset.defaultCharset()));
if (hadError) System.exit(65);
if (hadRuntimeError) System.exit(70);
}
@ -49,11 +53,12 @@ public class Cobalt {
Scanner scanner = new Scanner(source);
List<Token> tokens = scanner.scanTokens();
Parser parser = new Parser(tokens);
Expr expression = parser.parse();
List<Stmt> statements = parser.parse();
if (hadError) return;
System.out.println(new AstPrinter().print(expression));
//System.out.println(new AstPrinter().print(expression));
interpreter.interpret(statements);
}
@ -77,4 +82,9 @@ public class Cobalt {
report(token.line, " at '" + token.lexeme + "'", message);
}
}
static void runtimeError(RuntimeError error) {
System.err.println("[line " + error.token.line + "] " + error.getMessage());
hadRuntimeError = true;
}
}

View File

@ -1,34 +0,0 @@
//
// Continue at Section 7.2.3
// Page 100
//
package cobalt.lang;
class Interpreter implements Expr.Visitor<Object> {
@Override
public Object visitLiteralExpr(Expr.Literal expr) {
return expr.value;
}
@Override
public Object visitUnaryExpr(Expr.Unary expr) {
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case MINUS:
return -(double)right;
}
return null;
}
@Override
public Object visitGroupingExpr(Expr.Grouping expr) {
return evaluate(expr.expression);
}
private Object evaluate(Expr expr) {
return expr.accept(this);
}
}

View File

@ -0,0 +1,194 @@
//
// Continue at Section 7.2.3
// Page 100
//
package cobalt.lang;
import java.util.List;
class Interpreter implements Expr.Visitor<Object>, Stmt.Visitor<Void> {
void interpret(List<Stmt> statements) {
try {
for (Stmt statement : statements) {
execute(statement);
}
} catch(RuntimeError error) {
Cobalt.runtimeError(error);
}
}
@Override
public Object visitLiteralExpr(Expr.Literal expr) {
return expr.value;
}
@Override
public Object visitUnaryExpr(Expr.Unary expr) {
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case BANG:
return !isTruthy(right);
case MINUS:
return -(double) right;
}
return null;
}
private boolean isTruthy(Object object) {
if (object == null)
return false;
if (object instanceof Boolean)
return (boolean) object;
return true;
}
@Override
public Object visitGroupingExpr(Expr.Grouping expr) {
return evaluate(expr.expression);
}
private Object evaluate(Expr expr) {
return expr.accept(this);
}
private void execute(Stmt stmt) {
stmt.accept(this);
}
@Override
public Void visitExpressionStmt(Stmt.Expression stmt) {
evaluate(stmt.expression);
return null;
}
@Override
public Void visitPrintStmt(Stmt.Print stmt) {
Object value = evaluate(stmt.expression);
System.out.println(stringify(value));
return null;
}
@Override
public Object visitBinaryExpr(Expr.Binary expr) {
Object left = evaluate(expr.left);
Object right = evaluate(expr.right);
switch (expr.operator.type) {
case GREATER:
checkNumberOperands(expr.operator, left, right);
return (double) left > (double) right;
case LESS:
checkNumberOperands(expr.operator, left, right);
return (double) left < (double) right;
case GREATER_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double) left >= (double) right;
case LESS_EQUAL:
checkNumberOperands(expr.operator, left, right);
return (double) left <= (double) right;
case BANG_EQUAL:
checkNumberOperands(expr.operator, left, right);
return !isEqual(left, right);
case EQUAL_EQUAL:
checkNumberOperands(expr.operator, left, right);
return isEqual(left, right);
case MINUS:
checkNumberOperand(expr.operator, right);
return (double) left - (double) right;
case PLUS:
if (left instanceof Double && right instanceof Double) {
return (double) left + (double) right;
} else if (left instanceof String && right instanceof String) {
return (String) left + (String) right;
} else if (left instanceof String && right instanceof Double) {
return (String) left + (Double) right;
} else if (left instanceof Double && right instanceof String) {
return (Double) left + (String) right;
}
throw new RuntimeError(expr.operator, "Operands must be of type number or string");
case SLASH:
checkNumberOperands(expr.operator, left, right);
checkNumberOperand(expr.operator, right);
return (double) left / (double) right;
case STAR:
checkNumberOperands(expr.operator, left, right);
boolean usingStrings = false;
String val = "";
double limit = 0;
if (left instanceof String) {
usingStrings = true;
val = (String)left;
limit = (Double)right;
} else if (right instanceof String) {
usingStrings = true;
val = (String)right;
limit = (Double)left;
}
if (usingStrings) {
String output = "";
for (int i = 0; i < limit; i++) {
output += val;
}
return output;
} else {
return (double) left * (double) right;
}
}
// Unreachable
return null;
}
private boolean isEqual(Object a, Object b) {
if (a == null && b == null)
return true;
if (a == null)
return false;
return a.equals(b);
}
private String stringify(Object object) {
if (object == null) return "nil";
if (object instanceof Double) {
String text = object.toString();
if (text.endsWith(".0")) {
text = text.substring(0, text.length() - 2);
}
return text;
}
return object.toString();
}
private void checkNumberOperand(Token operator, Object operand) {
if (operand instanceof Double) {
if (operator.type == TokenType.SLASH && (double)operand == 0) {
throw new RuntimeError(operator, "Division by zero is not allowed");
}
return;
}
throw new RuntimeError(operator, "Operand must be a number");
}
private void checkNumberOperands(Token operator, Object left, Object right) {
if (left instanceof Double && right instanceof Double) {
return;
}
if (left instanceof String && right instanceof Double) {
return;
}
if (left instanceof Double && right instanceof String) {
return;
}
throw new RuntimeError(operator, "Operands must be numbers");
}
}

View File

@ -1,6 +1,7 @@
package cobalt.lang;
import java.util.List;
import java.util.ArrayList;
import static cobalt.lang.TokenType.*;
class Parser {
@ -14,12 +15,13 @@ class Parser {
}
Expr parse() {
try {
return expression();
} catch (ParseError error) {
return null;
List<Stmt> parse() {
List<Stmt> statements = new ArrayList<>();
while (!isAtEnd()) {
statements.add(statement());
}
return statements;
}
@ -28,6 +30,26 @@ class Parser {
}
private Stmt statement() {
if (match(PRINT)) return printStatement();
return expressionStatement();
}
private Stmt printStatement() {
Expr value = expression();
consume(SEMICOLON, "Expect ';' after value.");
return new Stmt.Print(value);
}
private Stmt expressionStatement() {
Expr expr = expression();
consume(SEMICOLON, "Expect ';' after value.");
return new Stmt.Expression(expr);
}
// Parser definition for handling equalities
//
// equality -> comparison ( ( "!=" | "==" ) comparison )* ;

View File

@ -0,0 +1,10 @@
package cobalt.lang;
class RuntimeError extends RuntimeException {
final Token token;
RuntimeError(Token token, String message) {
super(message);
this.token = token;
}
}

60
cobalt/lang/Stmt.java Normal file
View File

@ -0,0 +1,60 @@
package cobalt.lang;
import java.util.List;
abstract class Stmt {
interface Visitor<R> {
//R visitBlockStmt(Block stmt);
R visitExpressionStmt(Expression stmt);
R visitPrintStmt(Print stmt);
//R visitVarStmt(Var stmt);
}
// static class Block extends Stmt {
// Block(List<Stmt> statements) {
// this.statements = statements;
// }
// <R> R accept(Visitor<R> visitor) {
// return visitor.visitBlockStmt(this);
// }
// final List<Stmt> statements;
// }
static class Expression extends Stmt {
Expression(Expr expression) {
this.expression = expression;
}
<R> R accept(Visitor<R> visitor) {
return visitor.visitExpressionStmt(this);
}
final Expr expression;
}
static class Print extends Stmt {
Print(Expr expression) {
this.expression = expression;
}
<R> R accept(Visitor<R> visitor) {
return visitor.visitPrintStmt(this);
}
final Expr expression;
}
// static class Var extends Stmt {
// Var(Token name, Expr initializer) {
// this.name = name;
// this.initializer = initializer;
// }
// <R> R accept(Visitor<R> visitor) {
// return visitor.visitVarStmt(this);
// }
// final Token name;
// final Expr initializer;
// }
abstract <R> R accept(Visitor<R> visitor);
}

View File

@ -9,16 +9,24 @@ public class GenerateAst {
public static void main(String[] args) throws IOException {
if (args.length != 1) {
System.err.println("Usage: generate_ast <output directory>");
System.exit(64);
System.exit(1);
}
String outputDir = args[0];
defineAst(outputDir, "Expr", Arrays.asList(
"Assign : Token name, Expr value",
"Binary : Expr left, Token operator, Expr right",
"Grouping : Expr expression",
"Literal : Object value",
"Unary : Token operator, Expr right"
"Unary : Token operator, Expr right",
"Variable : Token name"
));
defineAst(outputDir, "Stmt", Arrays.asList(
"Block : List<Stmt> statements",
"Expression : Expr expression",
"Print : Expr expression",
"Var : Token name, Expr initializer"
));
}
@ -26,28 +34,47 @@ public class GenerateAst {
String path = outputDir + "/" + baseName + ".java";
PrintWriter writer = new PrintWriter(path, "UTF-8");
writer.println("package cobalt.lang");
writer.println();
writer.println("package com.craftinginterpreters.lox;");
writer.println("");
writer.println("import java.util.List;");
writer.println();
writer.println("");
writer.println("abstract class " + baseName + " {");
for (String type : types){
defineVisitor(writer, baseName, types);
// The AST classes.
for (String type : types) {
String className = type.split(":")[0].trim();
String fields = type.split(":")[1].trim();
defineType(writer, baseName, className, fields);
}
// The base accept() method.
writer.println("");
writer.println(" abstract <R> R accept(Visitor<R> visitor);");
writer.println("}");
writer.close();
}
private static void defineVisitor(PrintWriter writer, String baseName, List<String> types) {
writer.println(" interface Visitor<R> {");
for (String type : types) {
String typeName = type.split(":")[0].trim();
writer.println(" R visit" + typeName + baseName + "(" + typeName + " " + baseName.toLowerCase() + ");");
}
writer.println(" }");
}
private static void defineType(PrintWriter writer, String baseName, String className, String fieldList) {
writer.println(" static class " + className + " extends " + baseName + " {");
// Constructor
// Constructor.
writer.println(" " + className + "(" + fieldList + ") {");
// Store parameters in fields.
String[] fields = fieldList.split(", ");
for (String field : fields) {
String name = field.split(" ")[1];
@ -56,7 +83,13 @@ public class GenerateAst {
writer.println(" }");
// Fields
// Visitor pattern.
writer.println();
writer.println(" <R> R accept(Visitor<R> visitor) {");
writer.println(" return visitor.visit" + className + baseName + "(this);");
writer.println(" }");
// Fields.
writer.println();
for (String field : fields) {
writer.println(" final " + field + ";");

5
examples/test.cobalt Normal file
View File

@ -0,0 +1,5 @@
print "These are statement tests for Cobalt!";
print 10 + 12;
print 13 + (14 - 7) * 3;
print "Hello, " + "World!";
print "String " + "Multiplication " * 3;